ConnStatus = pq.ConnStatus
TransactionStatus = pq.TransactionStatus
+ _info_class: Type[ConnectionInfo] = ConnectionInfo
+
def __init__(self, pgconn: "PGconn"):
self.pgconn = pgconn
self._autocommit = False
- self._adapters = AdaptersMap(postgres.adapters)
+
+ # None, but set to a copy of the global adapters map as soon as requested.
+ self._adapters: Optional[AdaptersMap] = None
+
self._notice_handlers: List[NoticeHandler] = []
self._notify_handlers: List[NotifyHandler] = []
@property
def info(self) -> ConnectionInfo:
"""A `ConnectionInfo` attribute to inspect connection properties."""
- return ConnectionInfo(self.pgconn)
+ return self._info_class(self.pgconn)
@property
def adapters(self) -> AdaptersMap:
+ if not self._adapters:
+ self._adapters = AdaptersMap(postgres.adapters)
+
return self._adapters
@property
rv.cursor_factory = cursor_factory
if context:
rv._adapters = AdaptersMap(context.adapters)
+
rv.prepare_threshold = prepare_threshold
+
+ # TODOCRDB find the right place for this operation
+ if rv.pgconn.parameter_status(b"crdb_version"):
+ from .crdb import customize_crdb_connection
+
+ customize_crdb_connection(rv)
+
return rv
def __enter__(self: _Self) -> _Self:
if context:
rv._adapters = AdaptersMap(context.adapters)
rv.prepare_threshold = prepare_threshold
+
+ # TODOCRDB find the right place for this operation
+ if rv.pgconn.parameter_status(b"crdb_version"):
+ from .crdb import customize_crdb_connection
+
+ customize_crdb_connection(rv)
+
return rv
async def __aenter__(self: _Self) -> _Self:
def __init__(self, pgconn: pq.abc.PGconn):
self.pgconn = pgconn
+ @property
+ def vendor(self) -> str:
+ return "PostgreSQL"
+
@property
def host(self) -> str:
"""The server host name of the active connection. See :pq:`PQhost()`."""
--- /dev/null
+"""
+Types configuration specific for CockroachDB.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import re
+from typing import Any, Optional, Union, TYPE_CHECKING
+
+from . import errors as e
+from .abc import AdaptContext
+from .postgres import adapters as pg_adapters
+from ._adapters_map import AdaptersMap
+from .conninfo import ConnectionInfo
+
+adapters = AdaptersMap(pg_adapters)
+
+if TYPE_CHECKING:
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+
+
+class CrdbConnectionInfo(ConnectionInfo):
+ @property
+ def vendor(self) -> str:
+ return "CockroachDB"
+
+ @property
+ def crdb_version(self) -> int:
+ """
+ Return the CockroachDB server version connected.
+
+ Return None if the server is not CockroachDB, else return a number in
+ the PostgreSQL format (e.g. 21.2.10 -> 200210)
+
+ Assume all the connections are on the same db: return a cached result on
+ following calls.
+ """
+ sver = self.parameter_status("crdb_version")
+ if not sver:
+ raise e.InternalError("'crdb_version' parameter status not set")
+
+ ver = self.parse_crdb_version(sver)
+ if ver is None:
+ raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
+
+ return ver
+
+ @classmethod
+ def parse_crdb_version(self, sver: str) -> Optional[int]:
+ m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
+ if not m:
+ return None
+
+ return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
+
+
+def register_crdb_adapters(context: AdaptContext) -> None:
+ from .types import string
+
+ adapters = context.adapters
+
+ # Dump strings with text oid instead of unknown.
+ # Unlike PostgreSQL, CRDB seems able to cast text to most types.
+ adapters.register_dumper(str, string.StrDumper)
+
+
+register_crdb_adapters(adapters)
+
+
+def customize_crdb_connection(
+ conn: "Union[Connection[Any], AsyncConnection[Any]]",
+) -> None:
+ conn._info_class = CrdbConnectionInfo
+
+ # TODOCRDB: what if someone is passing context? they will have
+ # customised the postgres adapters, so those changes wouldn't apply
+ # to crdb (e.g. the Django backend in preparation).
+ if conn._adapters is None:
+ # Not customized by connect()
+ conn._adapters = AdaptersMap(adapters)
--- /dev/null
+from copy import deepcopy
+
+import pytest
+
+import psycopg.crdb
+from psycopg.adapt import PyFormat, Transformer
+from psycopg.types.array import ListDumper
+from psycopg.postgres import types as builtins
+
+from ..test_adapt import MyStr, make_dumper, make_bin_dumper
+from ..test_adapt import make_loader, make_bin_loader
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_return_untyped(conn, fmt_in):
+ # Analyze and check for changes using strings in untyped/typed contexts
+ cur = conn.cursor()
+ # Currently string are passed as text oid to CockroachDB, unlike Postgres,
+ # to which strings are passed as unknown. This is because CRDB doesn't
+ # allow the unknown oid to be emitted; execute("SELECT %s", ["str"]) raises
+ # an error. However, unlike PostgreSQL, text can be cast to any other type.
+ cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10])
+ assert cur.fetchone() == ("hello", 10)
+
+ cur.execute("create table testjson(data jsonb)")
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+ assert cur.execute("select data from testjson").fetchone() == ({},)
+
+
+def test_str_list_dumper_text(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.TEXT)
+ assert isinstance(dstr, ListDumper)
+ assert dstr.oid == builtins["text"].array_oid
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
+
+
+@pytest.fixture
+def crdb_adapters():
+ """Restore the crdb adapters after a test has changed them."""
+ from psycopg.crdb import adapters
+
+ dumpers = deepcopy(adapters._dumpers)
+ dumpers_by_oid = deepcopy(adapters._dumpers_by_oid)
+ loaders = deepcopy(adapters._loaders)
+ types = list(adapters.types)
+
+ yield None
+
+ adapters._dumpers = dumpers
+ adapters._dumpers_by_oid = dumpers_by_oid
+ adapters._loaders = loaders
+ adapters.types.clear()
+ for t in types:
+ adapters.types.add(t)
+
+
+def test_dump_global_ctx(dsn, crdb_adapters, pgconn):
+ psycopg.crdb.adapters.register_dumper(MyStr, make_bin_dumper("gb"))
+ psycopg.crdb.adapters.register_dumper(MyStr, make_dumper("gt"))
+ with psycopg.connect(dsn) as conn:
+ cur = conn.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogb",)
+ cur = conn.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+
+
+def test_load_global_ctx(dsn, crdb_adapters):
+ psycopg.crdb.adapters.register_loader("text", make_loader("gt"))
+ psycopg.crdb.adapters.register_loader("text", make_bin_loader("gb"))
+ with psycopg.connect(dsn) as conn:
+ cur = conn.cursor(binary=False).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.cursor(binary=True).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogb",)
--- /dev/null
+import pytest
+
+pytestmark = pytest.mark.crdb
+
+
+def test_vendor(conn):
+ assert conn.info.vendor == "CockroachDB"
+
+
+def test_crdb_version(conn):
+ assert conn.info.crdb_version > 200000
-import re
-
from .utils import check_version
item.function.crdb_reason = m.kwargs.get("reason")
-def check_crdb_version(pgconn, func):
+def check_crdb_version(got, func):
"""
Verify if the CockroachDB version is a version accepted.
and skips the test if the server version doesn't match what expected.
"""
want = func.want_crdb
- got = get_crdb_version(pgconn)
+ rv = None
+
if got is None:
if want == "only":
return "skipping test: CockroachDB only"
else:
if want == "only":
- rv = None
+ pass
elif want == "skip":
rv = "skipping test: not supported on CockroachDB"
else:
f"issues/{crdb_reasons[func.crdb_reason]}"
)
rv = f"{rv} ({url})"
- return rv
-
-
-def get_crdb_version(pgconn, __crdb_version=[]):
- """
- Return the CockroachDB server version connected.
-
- Return None if the server is not CockroachDB, else return a number in
- the PostgreSQL format (e.g. 21.2.10 -> 200210)
-
- Assume all the connections are on the same db: return a cached result on
- following calls.
- """
- if __crdb_version:
- return __crdb_version[0]
-
- bver = pgconn.parameter_status(b"crdb_version")
- if bver:
- sver = bver.decode()
- m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
- if not m:
- raise ValueError(f"can't parse CockroachDB version from {sver}")
-
- ver = int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
-
- else:
- ver = None
-
- __crdb_version.append(ver)
- return __crdb_version[0]
+ return rv
# mapping from reason description to ticket number
import pytest
import logging
from contextlib import contextmanager
-from typing import List
+from typing import List, Optional
import psycopg
from psycopg import pq
from .utils import check_libpq_version, check_server_version
+# Set by warm_up_database() the first time the dsn fixture is used
+pg_version: int
+crdb_version: Optional[int]
+
def pytest_addoption(parser):
parser.addoption(
@pytest.fixture(scope="session")
-def dsn(request):
- """Return the dsn used to connect to the `--test-dsn` database."""
+def session_dsn(request):
+ """
+ Return the dsn used to connect to the `--test-dsn` database (session-wide).
+ """
dsn = request.config.getoption("--test-dsn")
if dsn is None:
pytest.skip("skipping test as no --test-dsn")
return dsn
+@pytest.fixture
+def dsn(session_dsn, request):
+ """Return the dsn used to connect to the `--test-dsn` database."""
+ msg = check_connection_version(request.function)
+ if msg:
+ pytest.skip(msg)
+
+ return session_dsn
+
+
@pytest.fixture(scope="session")
def tracefile(request):
"""Open and yield a file for libpq client/server communication traces if
@pytest.fixture
def pgconn(dsn, request, tracefile):
"""Return a PGconn connection open to `--test-dsn`."""
+ msg = check_connection_version(request.function)
+ if msg:
+ pytest.skip(msg)
+
from psycopg import pq
conn = pq.PGconn.connect(dsn.encode())
if conn.status != pq.ConnStatus.OK:
pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}")
- msg = check_connection_version(conn, request.function)
- if msg:
- conn.finish()
- pytest.skip(msg)
with maybe_trace(conn, tracefile, request.function):
yield conn
@pytest.fixture
def conn(dsn, request, tracefile):
"""Return a `Connection` connected to the ``--test-dsn`` database."""
+ msg = check_connection_version(request.function)
+ if msg:
+ pytest.skip(msg)
+
from psycopg import Connection
conn = Connection.connect(dsn)
- msg = check_connection_version(conn.pgconn, request.function)
- if msg:
- conn.close()
- pytest.skip(msg)
with maybe_trace(conn.pgconn, tracefile, request.function):
yield conn
conn.close()
@pytest.fixture
async def aconn(dsn, request, tracefile):
"""Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
+ msg = check_connection_version(request.function)
+ if msg:
+ pytest.skip(msg)
+
from psycopg import AsyncConnection
conn = await AsyncConnection.connect(dsn)
- msg = check_connection_version(conn.pgconn, request.function)
- if msg:
- await conn.close()
- pytest.skip(msg)
with maybe_trace(conn.pgconn, tracefile, request.function):
yield conn
await conn.close()
@pytest.fixture(scope="session")
-def svcconn(dsn):
+def svcconn(session_dsn):
"""
Return a session `Connection` connected to the ``--test-dsn`` database.
"""
from psycopg import Connection
- conn = Connection.connect(dsn, autocommit=True)
+ conn = Connection.connect(session_dsn, autocommit=True)
yield conn
conn.close()
return out
-def check_connection_version(pgconn, function):
+def check_connection_version(function):
+ try:
+ pg_version
+ except NameError:
+ # First connection creation failed. Let the tests fail.
+ return None
+
if hasattr(function, "want_pg_version"):
- rv = check_server_version(pgconn, function.want_pg_version)
+ rv = check_server_version(pg_version, function)
if rv:
return rv
if hasattr(function, "want_crdb"):
from .fix_crdb import check_crdb_version
- rv = check_crdb_version(pgconn, function)
+ rv = check_crdb_version(crdb_version, function)
if rv:
return rv
return
del __first_connection[:]
+ global pg_version, crdb_version
+
import psycopg
with psycopg.connect(dsn, connect_timeout=10) as conn:
conn.execute("select 1")
+
+ pg_version = conn.info.server_version
+ if conn.info.vendor == "CockroachDB":
+ crdb_version = conn.info.crdb_version # type: ignore
+ else:
+ crdb_version = None
assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper
-def test_dump_global_ctx(dsn, global_adapters):
+@pytest.mark.crdb("skip", reason="global adapters don't affect crdb")
+def test_dump_global_ctx(dsn, global_adapters, pgconn):
psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb"))
psycopg.adapters.register_dumper(MyStr, make_dumper("gt"))
- conn = psycopg.connect(dsn)
- cur = conn.execute("select %s", [MyStr("hello")])
- assert cur.fetchone() == ("hellogt",)
- cur = conn.execute("select %b", [MyStr("hello")])
- assert cur.fetchone() == ("hellogb",)
- cur = conn.execute("select %t", [MyStr("hello")])
- assert cur.fetchone() == ("hellogt",)
- conn.close()
+ with psycopg.connect(dsn) as conn:
+ cur = conn.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogb",)
+ cur = conn.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
def test_dump_connection_ctx(conn):
assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader
+@pytest.mark.crdb("skip", reason="global adapters don't affect crdb")
def test_load_global_ctx(dsn, global_adapters):
psycopg.adapters.register_loader("text", make_loader("gt"))
psycopg.adapters.register_loader("text", make_bin_loader("gb"))
- conn = psycopg.connect(dsn)
- cur = conn.cursor(binary=False).execute("select 'hello'::text")
- assert cur.fetchone() == ("hellogt",)
- cur = conn.cursor(binary=True).execute("select 'hello'::text")
- assert cur.fetchone() == ("hellogb",)
- conn.close()
+ with psycopg.connect(dsn) as conn:
+ cur = conn.cursor(binary=False).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.cursor(binary=True).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogb",)
def test_load_connection_ctx(conn):
@pytest.mark.parametrize("fmt_out", pq.Format)
-def test_array_dumper(conn, fmt_out):
+def test_list_dumper(conn, fmt_out):
t = Transformer(conn)
fmt_in = PyFormat.from_pq(fmt_out)
dint = t.get_dumper([0], fmt_in)
assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid
dstr = t.get_dumper([""], fmt_in)
- if fmt_in == PyFormat.BINARY:
- assert isinstance(dstr, ListBinaryDumper)
- assert dstr.oid == builtins["text"].array_oid
- assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
- else:
- assert isinstance(dstr, ListDumper)
- assert dstr.oid == 0
- assert dstr.sub_dumper and dstr.sub_dumper.oid == 0
-
assert dstr is not dint
assert t.get_dumper([1], fmt_in) is dint
assert t.get_dumper(L, fmt_in)
+@pytest.mark.crdb("skip", reason="test in crdb test suite")
+def test_str_list_dumper_text(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.TEXT)
+ assert isinstance(dstr, ListDumper)
+ assert dstr.oid == 0
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == 0
+
+
+def test_str_list_dumper_binary(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.BINARY)
+ assert isinstance(dstr, ListBinaryDumper)
+ assert dstr.oid == builtins["text"].array_oid
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
+
+
def test_last_dumper_registered_ctx(conn):
cur = conn.cursor()
assert cur.fetchone()[0]
+@pytest.mark.crdb("skip", reason="test in crdb test suite")
@pytest.mark.parametrize("fmt_in", PyFormat)
def test_return_untyped(conn, fmt_in):
# Analyze and check for changes using strings in untyped/typed contexts
with pytest.raises(psycopg.NotSupportedError):
cur.execute("select 'x'")
+ def test_vendor(self, conn):
+ assert conn.info.vendor
+
@pytest.mark.parametrize(
"conninfo, want, env",
@pytest.fixture(scope="class")
-def with_dsn(request, dsn):
- request.cls.connect_args = (dsn,)
+def with_dsn(request, session_dsn):
+ request.cls.connect_args = (session_dsn,)
@pytest.mark.usefixtures("with_dsn")
return check_version(got, want, "libpq")
-def check_server_version(pgconn, want):
+def check_server_version(got, function):
"""
Verify if the server version is a version accepted.
and skips the test if the server version doesn't match what expected.
"""
- got = pgconn.server_version
+ want = function.want_pg_version
return check_version(got, want, "server")