From: Daniele Varrazzo Date: Tue, 29 Mar 2022 19:35:48 +0000 (+0200) Subject: test: add fixture to write tests based on libpq trace data X-Git-Tag: 3.1~145^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e2934f356cf4311cffb40ce55572e5b68f639a9d;p=thirdparty%2Fpsycopg.git test: add fixture to write tests based on libpq trace data --- diff --git a/tests/fix_pq.py b/tests/fix_pq.py index 055178c56..912378dab 100644 --- a/tests/fix_pq.py +++ b/tests/fix_pq.py @@ -1,9 +1,16 @@ import sys +from typing import Iterator, List, NamedTuple +from tempfile import TemporaryFile import pytest from .utils import check_libpq_version +try: + from psycopg import pq +except ImportError: + pq = None # type: ignore + def pytest_report_header(config): try: @@ -28,8 +35,6 @@ def pytest_configure(config): def pytest_runtest_setup(item): - from psycopg import pq - for m in item.iter_markers(name="libpq"): assert len(m.args) == 1 msg = check_libpq_version(pq.version(), m.args[0]) @@ -51,9 +56,71 @@ def libpq(): assert libname, "libpq libname not found" return ctypes.pydll.LoadLibrary(libname) except Exception as e: - from psycopg import pq - if pq.__impl__ == "binary": pytest.skip(f"can't load libpq for testing: {e}") else: raise + + +@pytest.fixture +def trace(libpq): + pqver = pq.__build_version__ or pq.version() + if pqver < 140000: + pytest.skip(f"trace not available on libpq {pqver}") + if sys.platform != "linux": + pytest.skip(f"trace not available on {sys.platform}") + + yield Tracer() + + +class Tracer: + def trace(self, conn): + pgconn: "pq.abc.PGconn" + + if hasattr(conn, "exec_"): + pgconn = conn + elif hasattr(conn, "cursor"): + pgconn = conn.pgconn + else: + raise Exception() + + return TraceLog(pgconn) + + +class TraceLog: + def __init__(self, pgconn: "pq.abc.PGconn"): + self.pgconn = pgconn + self.tempfile = TemporaryFile(buffering=0) + pgconn.trace(self.tempfile.fileno()) + pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS) + + def __del__(self): + if self.pgconn.status == pq.ConnStatus.OK: + self.pgconn.untrace() + self.tempfile.close() + + def __iter__(self) -> "Iterator[TraceEntry]": + self.tempfile.seek(0) + data = self.tempfile.read() + for entry in self._parse_entries(data): + yield entry + + def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]": + for line in data.splitlines(): + direction, length, type, *content = line.split(b"\t") + yield TraceEntry( + direction=direction.decode(), + length=int(length.decode()), + type=type.decode(), + # Note: the items encoding is not very solid: no escaped + # backslash, no escaped quotes. + # At the moment we don't need a proper parser. + content=[content[0]] if content else [], + ) + + +class TraceEntry(NamedTuple): + direction: str + length: int + type: str + content: List[bytes]