]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test: add fixture to write tests based on libpq trace data
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 29 Mar 2022 19:35:48 +0000 (21:35 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:23:22 +0000 (01:23 +0200)
tests/fix_pq.py

index 055178c564543f60855de6ec1d3c1dae5cbbe117..912378dabebe87486b8d18a716e5a84cfc0c7845 100644 (file)
@@ -1,9 +1,16 @@
 import sys
+from typing import Iterator, List, NamedTuple
+from tempfile import TemporaryFile
 
 import pytest
 
 from .utils import check_libpq_version
 
+try:
+    from psycopg import pq
+except ImportError:
+    pq = None  # type: ignore
+
 
 def pytest_report_header(config):
     try:
@@ -28,8 +35,6 @@ def pytest_configure(config):
 
 
 def pytest_runtest_setup(item):
-    from psycopg import pq
-
     for m in item.iter_markers(name="libpq"):
         assert len(m.args) == 1
         msg = check_libpq_version(pq.version(), m.args[0])
@@ -51,9 +56,71 @@ def libpq():
         assert libname, "libpq libname not found"
         return ctypes.pydll.LoadLibrary(libname)
     except Exception as e:
-        from psycopg import pq
-
         if pq.__impl__ == "binary":
             pytest.skip(f"can't load libpq for testing: {e}")
         else:
             raise
+
+
+@pytest.fixture
+def trace(libpq):
+    pqver = pq.__build_version__ or pq.version()
+    if pqver < 140000:
+        pytest.skip(f"trace not available on libpq {pqver}")
+    if sys.platform != "linux":
+        pytest.skip(f"trace not available on {sys.platform}")
+
+    yield Tracer()
+
+
+class Tracer:
+    def trace(self, conn):
+        pgconn: "pq.abc.PGconn"
+
+        if hasattr(conn, "exec_"):
+            pgconn = conn
+        elif hasattr(conn, "cursor"):
+            pgconn = conn.pgconn
+        else:
+            raise Exception()
+
+        return TraceLog(pgconn)
+
+
+class TraceLog:
+    def __init__(self, pgconn: "pq.abc.PGconn"):
+        self.pgconn = pgconn
+        self.tempfile = TemporaryFile(buffering=0)
+        pgconn.trace(self.tempfile.fileno())
+        pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS)
+
+    def __del__(self):
+        if self.pgconn.status == pq.ConnStatus.OK:
+            self.pgconn.untrace()
+        self.tempfile.close()
+
+    def __iter__(self) -> "Iterator[TraceEntry]":
+        self.tempfile.seek(0)
+        data = self.tempfile.read()
+        for entry in self._parse_entries(data):
+            yield entry
+
+    def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]":
+        for line in data.splitlines():
+            direction, length, type, *content = line.split(b"\t")
+            yield TraceEntry(
+                direction=direction.decode(),
+                length=int(length.decode()),
+                type=type.decode(),
+                # Note: the items encoding is not very solid: no escaped
+                # backslash, no escaped quotes.
+                # At the moment we don't need a proper parser.
+                content=[content[0]] if content else [],
+            )
+
+
+class TraceEntry(NamedTuple):
+    direction: str
+    length: int
+    type: str
+    content: List[bytes]