From: Daniele Varrazzo Date: Tue, 17 May 2022 00:28:53 +0000 (+0200) Subject: feat(crdb): cusomize CockroachDB connection X-Git-Tag: 3.1~49^2~64 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=066b5dfb96ddee83b85dc1f10b48ac2c19e3b43e;p=thirdparty%2Fpsycopg.git feat(crdb): cusomize CockroachDB connection - Add ConnectionInfo.vendor - Add ConnectionInfo subclassing - Add CrdbConnectionInfo subclass with crdb_version attribute - Add 'psycopg.crdb' module with adapters - Dump strings using the text oid by default on CockroachDB The latter change might have wider consequences, however crdb casts strings to other types more easily than what Postgres does. At a glance it might work; porting the rest of the test suite will tell. test_adapt tests ported. Tests showing a difference moved to crdb/test_adapt. --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 0cc24ea20..7937c4552 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -114,10 +114,15 @@ class BaseConnection(Generic[Row]): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus + _info_class: Type[ConnectionInfo] = ConnectionInfo + def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn self._autocommit = False - self._adapters = AdaptersMap(postgres.adapters) + + # None, but set to a copy of the global adapters map as soon as requested. + self._adapters: Optional[AdaptersMap] = None + self._notice_handlers: List[NoticeHandler] = [] self._notify_handlers: List[NotifyHandler] = [] @@ -280,10 +285,13 @@ class BaseConnection(Generic[Row]): @property def info(self) -> ConnectionInfo: """A `ConnectionInfo` attribute to inspect connection properties.""" - return ConnectionInfo(self.pgconn) + return self._info_class(self.pgconn) @property def adapters(self) -> AdaptersMap: + if not self._adapters: + self._adapters = AdaptersMap(postgres.adapters) + return self._adapters @property @@ -732,7 +740,15 @@ class Connection(BaseConnection[Row]): rv.cursor_factory = cursor_factory if context: rv._adapters = AdaptersMap(context.adapters) + rv.prepare_threshold = prepare_threshold + + # TODOCRDB find the right place for this operation + if rv.pgconn.parameter_status(b"crdb_version"): + from .crdb import customize_crdb_connection + + customize_crdb_connection(rv) + return rv def __enter__(self: _Self) -> _Self: diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 1545667b7..334e60dde 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -136,6 +136,13 @@ class AsyncConnection(BaseConnection[Row]): if context: rv._adapters = AdaptersMap(context.adapters) rv.prepare_threshold = prepare_threshold + + # TODOCRDB find the right place for this operation + if rv.pgconn.parameter_status(b"crdb_version"): + from .crdb import customize_crdb_connection + + customize_crdb_connection(rv) + return rv async def __aenter__(self: _Self) -> _Self: diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 1bc85ad5c..f845e8255 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -124,6 +124,10 @@ class ConnectionInfo: def __init__(self, pgconn: pq.abc.PGconn): self.pgconn = pgconn + @property + def vendor(self) -> str: + return "PostgreSQL" + @property def host(self) -> str: """The server host name of the active connection. See :pq:`PQhost()`.""" diff --git a/psycopg/psycopg/crdb.py b/psycopg/psycopg/crdb.py new file mode 100644 index 000000000..e509d78ec --- /dev/null +++ b/psycopg/psycopg/crdb.py @@ -0,0 +1,81 @@ +""" +Types configuration specific for CockroachDB. +""" + +# Copyright (C) 2022 The Psycopg Team + +import re +from typing import Any, Optional, Union, TYPE_CHECKING + +from . import errors as e +from .abc import AdaptContext +from .postgres import adapters as pg_adapters +from ._adapters_map import AdaptersMap +from .conninfo import ConnectionInfo + +adapters = AdaptersMap(pg_adapters) + +if TYPE_CHECKING: + from .connection import Connection + from .connection_async import AsyncConnection + + +class CrdbConnectionInfo(ConnectionInfo): + @property + def vendor(self) -> str: + return "CockroachDB" + + @property + def crdb_version(self) -> int: + """ + Return the CockroachDB server version connected. + + Return None if the server is not CockroachDB, else return a number in + the PostgreSQL format (e.g. 21.2.10 -> 200210) + + Assume all the connections are on the same db: return a cached result on + following calls. + """ + sver = self.parameter_status("crdb_version") + if not sver: + raise e.InternalError("'crdb_version' parameter status not set") + + ver = self.parse_crdb_version(sver) + if ver is None: + raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}") + + return ver + + @classmethod + def parse_crdb_version(self, sver: str) -> Optional[int]: + m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver) + if not m: + return None + + return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3)) + + +def register_crdb_adapters(context: AdaptContext) -> None: + from .types import string + + adapters = context.adapters + + # Dump strings with text oid instead of unknown. + # Unlike PostgreSQL, CRDB seems able to cast text to most types. + adapters.register_dumper(str, string.StrDumper) + + +register_crdb_adapters(adapters) + + +def customize_crdb_connection( + conn: "Union[Connection[Any], AsyncConnection[Any]]", +) -> None: + conn._info_class = CrdbConnectionInfo + + # TODOCRDB: what if someone is passing context? they will have + # customised the postgres adapters, so those changes wouldn't apply + # to crdb (e.g. the Django backend in preparation). + if conn._adapters is None: + # Not customized by connect() + conn._adapters = AdaptersMap(adapters) diff --git a/tests/crdb/__init__.py b/tests/crdb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/crdb/test_adapt.py b/tests/crdb/test_adapt.py new file mode 100644 index 000000000..2b64e3ff6 --- /dev/null +++ b/tests/crdb/test_adapt.py @@ -0,0 +1,79 @@ +from copy import deepcopy + +import pytest + +import psycopg.crdb +from psycopg.adapt import PyFormat, Transformer +from psycopg.types.array import ListDumper +from psycopg.postgres import types as builtins + +from ..test_adapt import MyStr, make_dumper, make_bin_dumper +from ..test_adapt import make_loader, make_bin_loader + +pytestmark = pytest.mark.crdb + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_return_untyped(conn, fmt_in): + # Analyze and check for changes using strings in untyped/typed contexts + cur = conn.cursor() + # Currently string are passed as text oid to CockroachDB, unlike Postgres, + # to which strings are passed as unknown. This is because CRDB doesn't + # allow the unknown oid to be emitted; execute("SELECT %s", ["str"]) raises + # an error. However, unlike PostgreSQL, text can be cast to any other type. + cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10]) + assert cur.fetchone() == ("hello", 10) + + cur.execute("create table testjson(data jsonb)") + cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"]) + assert cur.execute("select data from testjson").fetchone() == ({},) + + +def test_str_list_dumper_text(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.TEXT) + assert isinstance(dstr, ListDumper) + assert dstr.oid == builtins["text"].array_oid + assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid + + +@pytest.fixture +def crdb_adapters(): + """Restore the crdb adapters after a test has changed them.""" + from psycopg.crdb import adapters + + dumpers = deepcopy(adapters._dumpers) + dumpers_by_oid = deepcopy(adapters._dumpers_by_oid) + loaders = deepcopy(adapters._loaders) + types = list(adapters.types) + + yield None + + adapters._dumpers = dumpers + adapters._dumpers_by_oid = dumpers_by_oid + adapters._loaders = loaders + adapters.types.clear() + for t in types: + adapters.types.add(t) + + +def test_dump_global_ctx(dsn, crdb_adapters, pgconn): + psycopg.crdb.adapters.register_dumper(MyStr, make_bin_dumper("gb")) + psycopg.crdb.adapters.register_dumper(MyStr, make_dumper("gt")) + with psycopg.connect(dsn) as conn: + cur = conn.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + cur = conn.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellogb",) + cur = conn.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + + +def test_load_global_ctx(dsn, crdb_adapters): + psycopg.crdb.adapters.register_loader("text", make_loader("gt")) + psycopg.crdb.adapters.register_loader("text", make_bin_loader("gb")) + with psycopg.connect(dsn) as conn: + cur = conn.cursor(binary=False).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogt",) + cur = conn.cursor(binary=True).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogb",) diff --git a/tests/crdb/test_conninfo.py b/tests/crdb/test_conninfo.py new file mode 100644 index 000000000..2b9f58b09 --- /dev/null +++ b/tests/crdb/test_conninfo.py @@ -0,0 +1,11 @@ +import pytest + +pytestmark = pytest.mark.crdb + + +def test_vendor(conn): + assert conn.info.vendor == "CockroachDB" + + +def test_crdb_version(conn): + assert conn.info.crdb_version > 200000 diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py index e87e4215a..06ee3aaad 100644 --- a/tests/fix_crdb.py +++ b/tests/fix_crdb.py @@ -1,5 +1,3 @@ -import re - from .utils import check_version @@ -26,7 +24,7 @@ def pytest_runtest_setup(item): item.function.crdb_reason = m.kwargs.get("reason") -def check_crdb_version(pgconn, func): +def check_crdb_version(got, func): """ Verify if the CockroachDB version is a version accepted. @@ -39,13 +37,14 @@ def check_crdb_version(pgconn, func): and skips the test if the server version doesn't match what expected. """ want = func.want_crdb - got = get_crdb_version(pgconn) + rv = None + if got is None: if want == "only": return "skipping test: CockroachDB only" else: if want == "only": - rv = None + pass elif want == "skip": rv = "skipping test: not supported on CockroachDB" else: @@ -60,37 +59,8 @@ def check_crdb_version(pgconn, func): f"issues/{crdb_reasons[func.crdb_reason]}" ) rv = f"{rv} ({url})" - return rv - - -def get_crdb_version(pgconn, __crdb_version=[]): - """ - Return the CockroachDB server version connected. - - Return None if the server is not CockroachDB, else return a number in - the PostgreSQL format (e.g. 21.2.10 -> 200210) - - Assume all the connections are on the same db: return a cached result on - following calls. - """ - if __crdb_version: - return __crdb_version[0] - - bver = pgconn.parameter_status(b"crdb_version") - if bver: - sver = bver.decode() - m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver) - if not m: - raise ValueError(f"can't parse CockroachDB version from {sver}") - - ver = int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3)) - - else: - ver = None - - __crdb_version.append(ver) - return __crdb_version[0] + return rv # mapping from reason description to ticket number diff --git a/tests/fix_db.py b/tests/fix_db.py index 15f9540ce..f4e6ed6b4 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -2,13 +2,17 @@ import os import pytest import logging from contextlib import contextmanager -from typing import List +from typing import List, Optional import psycopg from psycopg import pq from .utils import check_libpq_version, check_server_version +# Set by warm_up_database() the first time the dsn fixture is used +pg_version: int +crdb_version: Optional[int] + def pytest_addoption(parser): parser.addoption( @@ -70,8 +74,10 @@ def pytest_runtest_setup(item): @pytest.fixture(scope="session") -def dsn(request): - """Return the dsn used to connect to the `--test-dsn` database.""" +def session_dsn(request): + """ + Return the dsn used to connect to the `--test-dsn` database (session-wide). + """ dsn = request.config.getoption("--test-dsn") if dsn is None: pytest.skip("skipping test as no --test-dsn") @@ -87,6 +93,16 @@ def dsn(request): return dsn +@pytest.fixture +def dsn(session_dsn, request): + """Return the dsn used to connect to the `--test-dsn` database.""" + msg = check_connection_version(request.function) + if msg: + pytest.skip(msg) + + return session_dsn + + @pytest.fixture(scope="session") def tracefile(request): """Open and yield a file for libpq client/server communication traces if @@ -128,15 +144,15 @@ def maybe_trace(pgconn, tracefile, function): @pytest.fixture def pgconn(dsn, request, tracefile): """Return a PGconn connection open to `--test-dsn`.""" + msg = check_connection_version(request.function) + if msg: + pytest.skip(msg) + from psycopg import pq conn = pq.PGconn.connect(dsn.encode()) if conn.status != pq.ConnStatus.OK: pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}") - msg = check_connection_version(conn, request.function) - if msg: - conn.finish() - pytest.skip(msg) with maybe_trace(conn, tracefile, request.function): yield conn @@ -147,13 +163,13 @@ def pgconn(dsn, request, tracefile): @pytest.fixture def conn(dsn, request, tracefile): """Return a `Connection` connected to the ``--test-dsn`` database.""" + msg = check_connection_version(request.function) + if msg: + pytest.skip(msg) + from psycopg import Connection conn = Connection.connect(dsn) - msg = check_connection_version(conn.pgconn, request.function) - if msg: - conn.close() - pytest.skip(msg) with maybe_trace(conn.pgconn, tracefile, request.function): yield conn conn.close() @@ -175,13 +191,13 @@ def pipeline(request, conn): @pytest.fixture async def aconn(dsn, request, tracefile): """Return an `AsyncConnection` connected to the ``--test-dsn`` database.""" + msg = check_connection_version(request.function) + if msg: + pytest.skip(msg) + from psycopg import AsyncConnection conn = await AsyncConnection.connect(dsn) - msg = check_connection_version(conn.pgconn, request.function) - if msg: - await conn.close() - pytest.skip(msg) with maybe_trace(conn.pgconn, tracefile, request.function): yield conn await conn.close() @@ -201,13 +217,13 @@ async def apipeline(request, aconn): @pytest.fixture(scope="session") -def svcconn(dsn): +def svcconn(session_dsn): """ Return a session `Connection` connected to the ``--test-dsn`` database. """ from psycopg import Connection - conn = Connection.connect(dsn, autocommit=True) + conn = Connection.connect(session_dsn, autocommit=True) yield conn conn.close() @@ -254,16 +270,22 @@ class ListPopAll(list): # type: ignore[type-arg] return out -def check_connection_version(pgconn, function): +def check_connection_version(function): + try: + pg_version + except NameError: + # First connection creation failed. Let the tests fail. + return None + if hasattr(function, "want_pg_version"): - rv = check_server_version(pgconn, function.want_pg_version) + rv = check_server_version(pg_version, function) if rv: return rv if hasattr(function, "want_crdb"): from .fix_crdb import check_crdb_version - rv = check_crdb_version(pgconn, function) + rv = check_crdb_version(crdb_version, function) if rv: return rv @@ -294,7 +316,15 @@ def warm_up_database(dsn: str, __first_connection: List[bool] = [True]) -> None: return del __first_connection[:] + global pg_version, crdb_version + import psycopg with psycopg.connect(dsn, connect_timeout=10) as conn: conn.execute("select 1") + + pg_version = conn.info.server_version + if conn.info.vendor == "CockroachDB": + crdb_version = conn.info.crdb_version # type: ignore + else: + crdb_version = None diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 7437a71ab..7e76cb489 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -61,17 +61,17 @@ def test_register_dumper_by_class_name(conn): assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper -def test_dump_global_ctx(dsn, global_adapters): +@pytest.mark.crdb("skip", reason="global adapters don't affect crdb") +def test_dump_global_ctx(dsn, global_adapters, pgconn): psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb")) psycopg.adapters.register_dumper(MyStr, make_dumper("gt")) - conn = psycopg.connect(dsn) - cur = conn.execute("select %s", [MyStr("hello")]) - assert cur.fetchone() == ("hellogt",) - cur = conn.execute("select %b", [MyStr("hello")]) - assert cur.fetchone() == ("hellogb",) - cur = conn.execute("select %t", [MyStr("hello")]) - assert cur.fetchone() == ("hellogt",) - conn.close() + with psycopg.connect(dsn) as conn: + cur = conn.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + cur = conn.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellogb",) + cur = conn.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) def test_dump_connection_ctx(conn): @@ -198,15 +198,15 @@ def test_register_loader_by_type_name(conn): assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader +@pytest.mark.crdb("skip", reason="global adapters don't affect crdb") def test_load_global_ctx(dsn, global_adapters): psycopg.adapters.register_loader("text", make_loader("gt")) psycopg.adapters.register_loader("text", make_bin_loader("gb")) - conn = psycopg.connect(dsn) - cur = conn.cursor(binary=False).execute("select 'hello'::text") - assert cur.fetchone() == ("hellogt",) - cur = conn.cursor(binary=True).execute("select 'hello'::text") - assert cur.fetchone() == ("hellogb",) - conn.close() + with psycopg.connect(dsn) as conn: + cur = conn.cursor(binary=False).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogt",) + cur = conn.cursor(binary=True).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogb",) def test_load_connection_ctx(conn): @@ -289,7 +289,7 @@ def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out): @pytest.mark.parametrize("fmt_out", pq.Format) -def test_array_dumper(conn, fmt_out): +def test_list_dumper(conn, fmt_out): t = Transformer(conn) fmt_in = PyFormat.from_pq(fmt_out) dint = t.get_dumper([0], fmt_in) @@ -298,15 +298,6 @@ def test_array_dumper(conn, fmt_out): assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid dstr = t.get_dumper([""], fmt_in) - if fmt_in == PyFormat.BINARY: - assert isinstance(dstr, ListBinaryDumper) - assert dstr.oid == builtins["text"].array_oid - assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid - else: - assert isinstance(dstr, ListDumper) - assert dstr.oid == 0 - assert dstr.sub_dumper and dstr.sub_dumper.oid == 0 - assert dstr is not dint assert t.get_dumper([1], fmt_in) is dint @@ -323,6 +314,23 @@ def test_array_dumper(conn, fmt_out): assert t.get_dumper(L, fmt_in) +@pytest.mark.crdb("skip", reason="test in crdb test suite") +def test_str_list_dumper_text(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.TEXT) + assert isinstance(dstr, ListDumper) + assert dstr.oid == 0 + assert dstr.sub_dumper and dstr.sub_dumper.oid == 0 + + +def test_str_list_dumper_binary(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.BINARY) + assert isinstance(dstr, ListBinaryDumper) + assert dstr.oid == builtins["text"].array_oid + assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid + + def test_last_dumper_registered_ctx(conn): cur = conn.cursor() @@ -347,6 +355,7 @@ def test_none_type_argument(conn, fmt_in): assert cur.fetchone()[0] +@pytest.mark.crdb("skip", reason="test in crdb test suite") @pytest.mark.parametrize("fmt_in", PyFormat) def test_return_untyped(conn, fmt_in): # Analyze and check for changes using strings in untyped/typed contexts diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index d637ecbdc..5f0c0338b 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -302,6 +302,9 @@ class TestConnectionInfo: with pytest.raises(psycopg.NotSupportedError): cur.execute("select 'x'") + def test_vendor(self, conn): + assert conn.info.vendor + @pytest.mark.parametrize( "conninfo, want, env", diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index fc6bec44a..28c79bd3b 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -10,8 +10,8 @@ from . import dbapi20_tpc @pytest.fixture(scope="class") -def with_dsn(request, dsn): - request.cls.connect_args = (dsn,) +def with_dsn(request, session_dsn): + request.cls.connect_args = (session_dsn,) @pytest.mark.usefixtures("with_dsn") diff --git a/tests/utils.py b/tests/utils.py index 830eba614..1b1dfa88e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,7 +20,7 @@ def check_libpq_version(got, want): return check_version(got, want, "libpq") -def check_server_version(pgconn, want): +def check_server_version(got, function): """ Verify if the server version is a version accepted. @@ -30,7 +30,7 @@ def check_server_version(pgconn, want): and skips the test if the server version doesn't match what expected. """ - got = pgconn.server_version + want = function.want_pg_version return check_version(got, want, "server")