]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add bindings for PQ tracing functions
authorDenis Laxalde <denis.laxalde@dalibo.com>
Fri, 3 Dec 2021 07:05:16 +0000 (08:05 +0100)
committerDenis Laxalde <denis@laxalde.org>
Tue, 7 Dec 2021 21:36:53 +0000 (22:36 +0100)
Since we cannot pass a file descriptor as a FILE value, as expected by
PQtrace(), PGconn.trace() method accepts a 'fileno: int' value. This is
then used to build an stdio's FILE value through fdopen(). The latter
also needs a binding in ctypes. This only works on Linux platform.

In _pq_ctypes.pyi, fdopen() and PQtrace() are not autogenerated because
of the needed '# type: ignore' (similar to existing ones).

PQsetTraceFlags() is new from libpq 14, so we declare it conditionally.

psycopg/psycopg/pq/__init__.py
psycopg/psycopg/pq/_enums.py
psycopg/psycopg/pq/_pq_ctypes.py
psycopg/psycopg/pq/_pq_ctypes.pyi
psycopg/psycopg/pq/abc.py
psycopg/psycopg/pq/pq_ctypes.py
psycopg_c/psycopg_c/pq/libpq.pxd
psycopg_c/psycopg_c/pq/pgconn.pyx
tests/pq/test_pgconn.py

index bc1c70dfe9a3218bfa21ee3a5c63a5b4bd576025..c30ead5e730ca28dc666d29f2e73690c51535744 100644 (file)
@@ -16,7 +16,7 @@ from typing import Callable, List, Optional, Type
 from . import abc
 from .misc import ConninfoOption, PGnotify, PGresAttDesc
 from .misc import error_message
-from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format
+from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace
 from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus
 
 logger = logging.getLogger(__name__)
@@ -122,6 +122,7 @@ __all__ = (
     "Ping",
     "DiagnosticField",
     "Format",
+    "Trace",
     "PGconn",
     "PGnotify",
     "Conninfo",
index 8eca77b453fa127187691f596030a10ee9f3b659..88e3113ca4a6741dd6825260ac39c7b31579a2f3 100644 (file)
@@ -4,7 +4,7 @@ libpq enum definitions for psycopg
 
 # Copyright (C) 2020-2021 The Psycopg Team
 
-from enum import IntEnum, auto
+from enum import IntEnum, IntFlag, auto
 
 
 class ConnStatus(IntEnum):
@@ -232,3 +232,17 @@ class Format(IntEnum):
     """Text parameter."""
     BINARY = 1
     """Binary parameter."""
+
+
+class Trace(IntFlag):
+    """
+    Enum to control tracing of the client/server communication.
+    """
+
+    __module__ = "psycopg.pq"
+
+    SUPPRESS_TIMESTAMPS = 1
+    """Do not include timestamps in messages."""
+
+    REGRESS_MODE = 2
+    """Redact some fields, e.g. OIDs, from messages."""
index d7f326ed75b659a1b8966ee67b0c3b7a2859ca93..24cfd14bfa2d92e8c99b4c6c139a669520a454f3 100644 (file)
@@ -24,6 +24,23 @@ if not libname:
 
 pq = ctypes.cdll.LoadLibrary(libname)
 
+
+class FILE(Structure):
+    pass
+
+
+FILE_ptr = POINTER(FILE)
+
+if sys.platform == "linux":
+    libcname = ctypes.util.find_library("c")
+    assert libcname
+    libc = ctypes.cdll.LoadLibrary(libcname)
+
+    fdopen = libc.fdopen
+    fdopen.argtypes = (c_int, c_char_p)
+    fdopen.restype = FILE_ptr
+
+
 # Get the libpq version to define what functions are available.
 
 PQlibVersion = pq.PQlibVersion
@@ -551,6 +568,34 @@ PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int]
 PQgetCopyData.restype = c_int
 
 
+# 33.10. Control Functions
+
+PQtrace = pq.PQtrace
+PQtrace.argtypes = [PGconn_ptr, FILE_ptr]
+PQtrace.restype = None
+
+_PQsetTraceFlags = None
+
+if libpq_version >= 140000:
+    _PQsetTraceFlags = pq.PQsetTraceFlags
+    _PQsetTraceFlags.argtypes = [PGconn_ptr, c_int]
+    _PQsetTraceFlags.restype = None
+
+
+def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None:
+    if not _PQsetTraceFlags:
+        raise NotSupportedError(
+            f"PQsetTraceFlags requires libpq from PostgreSQL 14,"
+            f" {libpq_version} available instead"
+        )
+
+    _PQsetTraceFlags(pgconn, flags)
+
+
+PQuntrace = pq.PQuntrace
+PQuntrace.argtypes = [PGconn_ptr]
+PQuntrace.restype = None
+
 # 33.11. Miscellaneous Functions
 
 PQfreemem = pq.PQfreemem
@@ -715,6 +760,7 @@ def generate_stub() -> None:
             "LP_c_int",
             "LP_c_uint",
             "LP_c_ulong",
+            "LP_FILE",
         ):
             return f"pointer[{t.__name__[3:]}]"
 
index d451832e5d3edb76596ffff991c6b169260d6ba8..e19f8074b8fc0d7a38cd3dbcb60818affdb7288a 100644 (file)
@@ -8,6 +8,10 @@ from typing import Any, Callable, Optional, Sequence
 from ctypes import Array, pointer
 from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong
 
+class FILE: ...
+
+def fdopen(fd: int, mode: bytes) -> pointer[FILE]: ...  # type: ignore[type-var]
+
 Oid = c_uint
 
 class PGconn_struct: ...
@@ -110,6 +114,10 @@ def PQsetResultAttrs(
     arg2: int,
     arg3: Array[PGresAttDesc_struct],  # type: ignore
 ) -> int: ...
+def PQtrace(
+    arg1: Optional[PGconn_struct],
+    arg2: pointer[FILE],  # type: ignore[type-var]
+) -> None: ...
 def PQencryptPasswordConn(
     arg1: Optional[PGconn_struct],
     arg2: bytes,
@@ -197,6 +205,8 @@ def PQsetSingleRowMode(arg1: Optional[PGconn_struct]) -> int: ...
 def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
 def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
 def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ...
+def PQsetTraceFlags(arg1: Optional[PGconn_struct], arg2: int) -> None: ...
+def PQuntrace(arg1: Optional[PGconn_struct]) -> None: ...
 def PQfreemem(arg1: Any) -> None: ...
 def _PQencryptPasswordConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: bytes, arg4: bytes) -> Optional[bytes]: ...
 def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
index 63bb30a71f41a7771f6aab40d614869471aa9c86..a183185e999addc5bfbbe740a34d911d882d1111 100644 (file)
@@ -7,7 +7,7 @@ Protocol objects to represent objects exposed by different pq implementations.
 from typing import Any, Callable, List, Optional, Sequence, Tuple
 from typing import Union, TYPE_CHECKING
 
-from ._enums import Format
+from ._enums import Format, Trace
 from .._compat import Protocol
 
 if TYPE_CHECKING:
@@ -234,6 +234,15 @@ class PGconn(Protocol):
     def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
         ...
 
+    def trace(self, fileno: int) -> None:
+        ...
+
+    def set_trace_flags(self, flags: Trace) -> None:
+        ...
+
+    def untrace(self) -> None:
+        ...
+
     def encrypt_password(
         self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
     ) -> bytes:
index 42b86a5530bac461ef47aa95c1b9558e2db77d9a..f31ce1d01127746960edb4efe8a7e8831f094345 100644 (file)
@@ -10,6 +10,7 @@ implementation.
 
 import os
 import logging
+import sys
 from weakref import ref
 from functools import partial
 
@@ -22,7 +23,7 @@ from .. import errors as e
 from . import _pq_ctypes as impl
 from .misc import PGnotify, ConninfoOption, PGresAttDesc
 from .misc import error_message, connection_summary
-from ._enums import Format, ExecStatus
+from ._enums import Format, ExecStatus, Trace
 
 if TYPE_CHECKING:
     from . import abc
@@ -608,6 +609,18 @@ class PGconn:
         else:
             return nbytes, memoryview(b"")
 
+    def trace(self, fileno: int) -> None:
+        if sys.platform != "linux":
+            raise e.NotSupportedError("only supported on Linux")
+        stream = impl.fdopen(fileno, b"w")
+        impl.PQtrace(self._pgconn_ptr, stream)
+
+    def set_trace_flags(self, flags: Trace) -> None:
+        impl.PQsetTraceFlags(self._pgconn_ptr, flags)
+
+    def untrace(self) -> None:
+        impl.PQuntrace(self._pgconn_ptr)
+
     def encrypt_password(
         self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
     ) -> bytes:
index 78f7d56c70dad91a52b819e91c393bc7c8abbb10..7d28a4398f7b95641d4399047bac6e152fbe2770 100644 (file)
@@ -4,6 +4,11 @@ Libpq header definition for the cython psycopg.pq implementation.
 
 # Copyright (C) 2020-2021 The Psycopg Team
 
+cdef extern from "stdio.h":
+
+    ctypedef struct FILE:
+        pass
+
 cdef extern from "pg_config.h":
 
     int PG_VERSION_NUM
@@ -252,6 +257,11 @@ cdef extern from "libpq-fe.h":
     int PQputCopyEnd(PGconn *conn, const char *errormsg)
     int PQgetCopyData(PGconn *conn, char **buffer, int async)
 
+    # 33.10. Control Functions
+    void PQtrace(PGconn *conn, FILE *stream);
+    void PQsetTraceFlags(PGconn *conn, int flags);
+    void PQuntrace(PGconn *conn);
+
     # 33.11. Miscellaneous Functions
     void PQfreemem(void *ptr) nogil
     void PQconninfoFree(PQconninfoOption *connOptions)
index 0d5ee0c1bd32f335acfc0c25203980a63a3bedbd..c579311fe9257216305ec4a1c29da1cf1d940581 100644 (file)
@@ -17,13 +17,15 @@ cdef extern from * nogil:
     """
     pid_t getpid()
 
+from libc.stdio cimport fdopen
 from cpython.mem cimport PyMem_Malloc, PyMem_Free
 from cpython.bytes cimport PyBytes_AsString
 from cpython.memoryview cimport PyMemoryView_FromObject
 
 import logging
+import sys
 
-from psycopg.pq import Format as PqFormat
+from psycopg.pq import Format as PqFormat, Trace
 from psycopg.pq.misc import PGnotify, connection_summary
 from psycopg_c.pq cimport PQBuffer
 
@@ -504,6 +506,23 @@ cdef class PGconn:
         else:
             return nbytes, b""  # won't parse it, doesn't really be memoryview
 
+    def trace(self, fileno: int) -> None:
+        if sys.platform != "linux":
+            raise e.NotSupportedError("only supported on Linux")
+        stream = fdopen(fileno, b"w")
+        libpq.PQtrace(self._pgconn_ptr, stream)
+
+    def set_trace_flags(self, flags: Trace) -> None:
+        if libpq.PG_VERSION_NUM < 140000:
+            raise e.NotSupportedError(
+                f"PQsetTraceFlags requires libpq from PostgreSQL 14,"
+                f" {libpq.PG_VERSION_NUM} available instead"
+            )
+        libpq.PQsetTraceFlags(self._pgconn_ptr, flags)
+
+    def untrace(self) -> None:
+        libpq.PQuntrace(self._pgconn_ptr)
+
     def encrypt_password(
         self, const char *passwd, const char *user, algorithm = None
     ) -> bytes:
index 7b2e75e02fac8cf0c664f4a03a6cb2edc578a02c..fefab51bbf52f1f551894df74d25a9ef5f8e8751 100644 (file)
@@ -471,6 +471,50 @@ def test_notice_error(pgconn, caplog):
     assert "hello error" in rec.message
 
 
+@pytest.mark.libpq("< 14")
+@pytest.mark.skipif("sys.platform != 'linux'")
+def test_trace_pre14(pgconn, tmp_path):
+    tracef = tmp_path / "trace"
+    with tracef.open("w") as f:
+        pgconn.trace(f.fileno())
+        with pytest.raises(psycopg.NotSupportedError):
+            pgconn.set_trace_flags(0)
+        pgconn.exec_(b"select 1")
+        pgconn.untrace()
+        pgconn.exec_(b"select 2")
+    traces = tracef.read_text()
+    assert "select 1" in traces
+    assert "select 2" not in traces
+
+
+@pytest.mark.libpq(">= 14")
+@pytest.mark.skipif("sys.platform != 'linux'")
+def test_trace(pgconn, tmp_path):
+    tracef = tmp_path / "trace"
+    with tracef.open("w") as f:
+        pgconn.trace(f.fileno())
+        pgconn.set_trace_flags(
+            pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE
+        )
+        pgconn.exec_(b"select 1")
+        pgconn.untrace()
+        pgconn.exec_(b"select 2")
+    traces = [line.split("\t") for line in tracef.read_text().splitlines()]
+    assert traces == [
+        ["F", "13", "Query", ' "select 1"'],
+        ["B", "33", "RowDescription", ' 1 "?column?" NNNN 0 NNNN 4 -1 0'],
+        ["B", "11", "DataRow", " 1 1 '1'"],
+        ["B", "13", "CommandComplete", ' "SELECT 1"'],
+        ["B", "5", "ReadyForQuery", " I"],
+    ]
+
+
+@pytest.mark.skipif("sys.platform == 'linux'")
+def test_trace_nonlinux(pgconn):
+    with pytest.raises(psycopg.NotSupportedError):
+        pgconn.trace(1)
+
+
 @pytest.mark.libpq(">= 10")
 def test_encrypt_password(pgconn):
     enc = pgconn.encrypt_password(b"psycopg2", b"ashesh", b"md5")