ConnStatus = pq.ConnStatus
TransactionStatus = pq.TransactionStatus
- _info_class: Type[ConnectionInfo] = ConnectionInfo
-
def __init__(self, pgconn: "PGconn"):
self.pgconn = pgconn
self._autocommit = False
@property
def info(self) -> ConnectionInfo:
"""A `ConnectionInfo` attribute to inspect connection properties."""
- return self._info_class(self.pgconn)
+ return ConnectionInfo(self.pgconn)
@property
def adapters(self) -> AdaptersMap:
rv.cursor_factory = cursor_factory
if context:
rv._adapters = AdaptersMap(context.adapters)
-
rv.prepare_threshold = prepare_threshold
-
- # TODOCRDB find the right place for this operation
- if rv.pgconn.parameter_status(b"crdb_version"):
- from .crdb import customize_crdb_connection
-
- customize_crdb_connection(rv)
-
return rv
def __enter__(self: _Self) -> _Self:
if context:
rv._adapters = AdaptersMap(context.adapters)
rv.prepare_threshold = prepare_threshold
-
- # TODOCRDB find the right place for this operation
- if rv.pgconn.parameter_status(b"crdb_version"):
- from .crdb import customize_crdb_connection
-
- customize_crdb_connection(rv)
-
return rv
async def __aenter__(self: _Self) -> _Self:
import re
from enum import Enum
from typing import Any, Optional, Union, TYPE_CHECKING
+from ._typeinfo import TypeInfo, TypesRegistry
from . import errors as e
from .abc import AdaptContext
-from .postgres import adapters as pg_adapters, TEXT_OID
+from .rows import Row
+from .postgres import TEXT_OID
from .conninfo import ConnectionInfo
+from .connection import Connection
from ._adapters_map import AdaptersMap
+from .connection_async import AsyncConnection
from .types.enum import EnumDumper, EnumBinaryDumper
-adapters = AdaptersMap(pg_adapters)
-
if TYPE_CHECKING:
- from .connection import Connection
- from .connection_async import AsyncConnection
+ from .pq.abc import PGconn
+
+types = TypesRegistry()
+
+# Global adapter maps with PostgreSQL types configuration
+adapters = AdaptersMap(types=types)
+
+
+class _CrdbConnectionMixin:
+
+ _adapters: Optional[AdaptersMap]
+ pgconn: "PGconn"
+
+ @classmethod
+ def is_crdb(
+ cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"]
+ ) -> bool:
+ """
+ Return True if the server connected to ``conn`` is CockroachDB.
+ """
+ if isinstance(conn, (Connection, AsyncConnection)):
+ conn = conn.pgconn
+
+ return bool(conn.parameter_status(b"crdb_version"))
+
+ @property
+ def adapters(self) -> AdaptersMap:
+ if not self._adapters:
+ # By default, use CockroachDB adapters map
+ self._adapters = AdaptersMap(adapters)
+
+ return self._adapters
+
+ @property
+ def info(self) -> "CrdbConnectionInfo":
+ return CrdbConnectionInfo(self.pgconn)
+
+
+class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
+ pass
+
+
+class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
+ pass
class CrdbConnectionInfo(ConnectionInfo):
oid = TEXT_OID
+def register_postgres_adapters(context: AdaptContext) -> None:
+ # Same adapters used by PostgreSQL, or a good starting point for customization
+
+ from .types import array, bool, datetime, json, none, numeric, string, uuid
+
+ array.register_default_adapters(context)
+ bool.register_default_adapters(context)
+ datetime.register_default_adapters(context)
+ json.register_default_adapters(context)
+ none.register_default_adapters(context)
+ numeric.register_default_adapters(context)
+ string.register_default_adapters(context)
+ uuid.register_default_adapters(context)
+
+
def register_crdb_adapters(context: AdaptContext) -> None:
- from .types import string, json
+ from .types import array
+
+ register_postgres_adapters(context)
- adapters = context.adapters
+ # String must come after enum to map text oid -> string dumper
+ register_crdb_enum_adapters(context)
+ register_crdb_string_adapters(context)
+ register_crdb_json_adapters(context)
+ register_crdb_net_adapters(context)
+
+ array.register_all_arrays(adapters)
+
+
+def register_crdb_string_adapters(context: AdaptContext) -> None:
+ from .types import string
# Dump strings with text oid instead of unknown.
# Unlike PostgreSQL, CRDB seems able to cast text to most types.
- adapters.register_dumper(str, string.StrDumper)
- adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
- adapters.register_dumper(Enum, CrdbEnumDumper)
+ context.adapters.register_dumper(str, string.StrDumper)
+ context.adapters.register_dumper(str, string.StrBinaryDumper)
- # CRDB doesn't have json/jsonb: both dump as the jsonb oid
- adapters.register_dumper(json.Json, json.JsonbBinaryDumper)
- adapters.register_dumper(json.Json, json.JsonbDumper)
+def register_crdb_enum_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
+ context.adapters.register_dumper(Enum, CrdbEnumDumper)
-register_crdb_adapters(adapters)
+def register_crdb_json_adapters(context: AdaptContext) -> None:
+ from .types import json
+
+ # CRDB doesn't have json/jsonb: both dump as the jsonb oid
+ context.adapters.register_dumper(json.Json, json.JsonbBinaryDumper)
+ context.adapters.register_dumper(json.Json, json.JsonbDumper)
+
+
+def register_crdb_net_adapters(context: AdaptContext) -> None:
+ from psycopg.types import net
+
+ context.adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper)
+ context.adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper)
+ context.adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper)
+ context.adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper)
+ context.adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper)
+ context.adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper)
+ context.adapters.register_dumper(
+ "ipaddress.IPv4Interface", net.InterfaceBinaryDumper
+ )
+ context.adapters.register_dumper(
+ "ipaddress.IPv6Interface", net.InterfaceBinaryDumper
+ )
+ context.adapters.register_dumper(None, net.InetBinaryDumper)
+ context.adapters.register_loader("inet", net.InetLoader)
+ context.adapters.register_loader("inet", net.InetBinaryLoader)
+
+
+for t in [
+ TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb.
+ TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8
+ TypeInfo('"char"', 18, 1002), # special case, not generated
+ # autogenerated: start
+ # Generated from CockroachDB 22.1.0
+ TypeInfo("bit", 1560, 1561),
+ TypeInfo("bool", 16, 1000, regtype="boolean"),
+ TypeInfo("bpchar", 1042, 1014, regtype="character"),
+ TypeInfo("bytea", 17, 1001),
+ TypeInfo("date", 1082, 1182),
+ TypeInfo("float4", 700, 1021, regtype="real"),
+ TypeInfo("float8", 701, 1022, regtype="'double precision'"),
+ TypeInfo("inet", 869, 1041),
+ TypeInfo("int2", 21, 1005, regtype="smallint"),
+ TypeInfo("int2vector", 22, 1006),
+ TypeInfo("int4", 23, 1007),
+ TypeInfo("int8", 20, 1016, regtype="bigint"),
+ TypeInfo("interval", 1186, 1187),
+ TypeInfo("jsonb", 3802, 3807),
+ TypeInfo("name", 19, 1003),
+ TypeInfo("numeric", 1700, 1231),
+ TypeInfo("oid", 26, 1028),
+ TypeInfo("oidvector", 30, 1013),
+ TypeInfo("record", 2249, 2287),
+ TypeInfo("regclass", 2205, 2210),
+ TypeInfo("regnamespace", 4089, 4090),
+ TypeInfo("regproc", 24, 1008),
+ TypeInfo("regprocedure", 2202, 2207),
+ TypeInfo("regrole", 4096, 4097),
+ TypeInfo("regtype", 2206, 2211),
+ TypeInfo("text", 25, 1009),
+ TypeInfo("time", 1083, 1183, regtype="'time without time zone'"),
+ TypeInfo("timestamp", 1114, 1115, regtype="'timestamp without time zone'"),
+ TypeInfo("timestamptz", 1184, 1185, regtype="'timestamp with time zone'"),
+ TypeInfo("timetz", 1266, 1270, regtype="'time with time zone'"),
+ TypeInfo("unknown", 705, 0),
+ TypeInfo("uuid", 2950, 2951),
+ TypeInfo("varbit", 1562, 1563, regtype="'bit varying'"),
+ TypeInfo("varchar", 1043, 1015, regtype="'character varying'"),
+ # autogenerated: end
+]:
+ types.add(t)
-def customize_crdb_connection(
- conn: "Union[Connection[Any], AsyncConnection[Any]]",
-) -> None:
- conn._info_class = CrdbConnectionInfo
- # TODOCRDB: what if someone is passing context? they will have
- # customised the postgres adapters, so those changes wouldn't apply
- # to crdb (e.g. the Django backend in preparation).
- if conn._adapters is None:
- # Not customized by connect()
- conn._adapters = AdaptersMap(adapters)
+register_crdb_adapters(adapters)
import pytest
-import psycopg.crdb
+from psycopg.crdb import adapters, CrdbConnection
+
from psycopg.adapt import PyFormat, Transformer
from psycopg.types.array import ListDumper
from psycopg.postgres import types as builtins
@pytest.fixture
def crdb_adapters():
"""Restore the crdb adapters after a test has changed them."""
- from psycopg.crdb import adapters
-
dumpers = deepcopy(adapters._dumpers)
dumpers_by_oid = deepcopy(adapters._dumpers_by_oid)
loaders = deepcopy(adapters._loaders)
def test_dump_global_ctx(dsn, crdb_adapters, pgconn):
- psycopg.crdb.adapters.register_dumper(MyStr, make_bin_dumper("gb"))
- psycopg.crdb.adapters.register_dumper(MyStr, make_dumper("gt"))
- with psycopg.connect(dsn) as conn:
+ adapters.register_dumper(MyStr, make_bin_dumper("gb"))
+ adapters.register_dumper(MyStr, make_dumper("gt"))
+ with CrdbConnection.connect(dsn) as conn:
cur = conn.execute("select %s", [MyStr("hello")])
assert cur.fetchone() == ("hellogt",)
cur = conn.execute("select %b", [MyStr("hello")])
def test_load_global_ctx(dsn, crdb_adapters):
- psycopg.crdb.adapters.register_loader("text", make_loader("gt"))
- psycopg.crdb.adapters.register_loader("text", make_bin_loader("gb"))
- with psycopg.connect(dsn) as conn:
+ adapters.register_loader("text", make_loader("gt"))
+ adapters.register_loader("text", make_bin_loader("gb"))
+ with CrdbConnection.connect(dsn) as conn:
cur = conn.cursor(binary=False).execute("select 'hello'::text")
assert cur.fetchone() == ("hellogt",)
cur = conn.cursor(binary=True).execute("select 'hello'::text")
--- /dev/null
+from psycopg.crdb import CrdbConnection
+
+import pytest
+
+pytestmark = pytest.mark.crdb
+
+
+def test_is_crdb(conn):
+ assert CrdbConnection.is_crdb(conn)
+ assert CrdbConnection.is_crdb(conn.pgconn)
--- /dev/null
+from psycopg.crdb import CrdbConnection
+
+import pytest
+
+pytestmark = pytest.mark.crdb("skip")
+
+
+def test_is_crdb(conn):
+ assert not CrdbConnection.is_crdb(conn)
+ assert not CrdbConnection.is_crdb(conn.pgconn)
import pytest
from .utils import check_version
+from psycopg.crdb import CrdbConnection
def pytest_configure(config):
# Utility functions which can be imported in the test suite
-
-def is_crdb(conn):
- if hasattr(conn, "pgconn"):
- conn = conn.pgconn
-
- return bool(conn.parameter_status(b"crdb_version"))
+is_crdb = CrdbConnection.is_crdb
def crdb_skip_message(reason):
import psycopg
from psycopg import pq
+from psycopg import sql
+from psycopg.crdb import CrdbConnection
from .utils import check_libpq_version, check_server_version
"""Return a PGconn connection open to `--test-dsn`."""
check_connection_version(request.function)
- from psycopg import pq
-
conn = pq.PGconn.connect(dsn.encode())
if conn.status != pq.ConnStatus.OK:
pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}")
"""Return a `Connection` connected to the ``--test-dsn`` database."""
check_connection_version(request.function)
- from psycopg import Connection
+ cls = psycopg.Connection
+ if crdb_version:
+ cls = CrdbConnection
- conn = Connection.connect(dsn)
+ conn = cls.connect(dsn)
with maybe_trace(conn.pgconn, tracefile, request.function):
yield conn
conn.close()
"""Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
check_connection_version(request.function)
- from psycopg import AsyncConnection
+ cls = psycopg.AsyncConnection
+ if crdb_version:
+ from psycopg.crdb import AsyncCrdbConnection
+
+ cls = AsyncCrdbConnection
- conn = await AsyncConnection.connect(dsn)
+ conn = await cls.connect(dsn)
with maybe_trace(conn.pgconn, tracefile, request.function):
yield conn
await conn.close()
"""
Return a session `Connection` connected to the ``--test-dsn`` database.
"""
- from psycopg import Connection
-
- conn = Connection.connect(session_dsn, autocommit=True)
+ conn = psycopg.Connection.connect(session_dsn, autocommit=True)
yield conn
conn.close()
def patch_exec(conn, monkeypatch):
"""Helper to implement the commands fixture both sync and async."""
- from psycopg import sql
-
_orig_exec_command = conn._exec_command
L = ListPopAll()
@pytest.fixture
def hstore(svcconn):
- from psycopg import Error
-
try:
with svcconn.transaction():
svcconn.execute("create extension if not exists hstore")
- except Error as e:
+ except psycopg.Error as e:
pytest.skip(str(e))
global pg_version, crdb_version
- import psycopg
-
with psycopg.connect(dsn, connect_timeout=10) as conn:
conn.execute("select 1")
pg_version = conn.info.server_version
- if conn.info.vendor == "CockroachDB":
- crdb_version = conn.info.crdb_version # type: ignore
- else:
- crdb_version = None
+
+ crdb_version = None
+ param = conn.info.parameter_status("crdb_version")
+ if param:
+ from psycopg.crdb import CrdbConnectionInfo
+
+ crdb_version = CrdbConnectionInfo.parse_crdb_version(param)
"""
Update the maps of builtin types and names.
-You can update this file by executing it, using the PG* env var to connect
+This script updates some of the files in psycopg source code with data read
+from a database catalog.
"""
import re
+import argparse
import subprocess as sp
from typing import List
from pathlib import Path
+import psycopg
+from psycopg.rows import TupleRow
+from psycopg.crdb import CrdbConnection
+from psycopg._compat import TypeAlias
+
+Connection: TypeAlias = psycopg.Connection[TupleRow]
+
ROOT = Path(__file__).parent.parent
-version_sql = """
-select format($$
-# Generated from PostgreSQL %s.%s
-$$,
- setting::int / 10000, setting::int % 100) -- assume PG >= 10
- from pg_settings
- where name = 'server_version_num'
-"""
+def main() -> None:
+ opt = parse_cmdline()
+ conn = psycopg.connect(opt.dsn, autocommit=True)
+
+ if CrdbConnection.is_crdb(conn):
+ conn = CrdbConnection.connect(opt.dsn, autocommit=True)
+ update_crdb_python_oids(conn)
+ else:
+ update_python_oids(conn)
+ update_cython_oids(conn)
+
+
+def update_python_oids(conn: Connection) -> None:
+ fn = ROOT / "psycopg/psycopg/postgres.py"
+
+ lines = []
+ lines.extend(get_version_comment(conn))
+ lines.extend(get_py_types(conn))
+ lines.extend(get_py_ranges(conn))
+ lines.extend(get_py_multiranges(conn))
+
+ update_file(fn, lines)
+ sp.check_call(["black", "-q", fn])
+
+
+def update_cython_oids(conn: Connection) -> None:
+ fn = ROOT / "psycopg_c/psycopg_c/_psycopg/oids.pxd"
+
+ lines = []
+ lines.extend(get_version_comment(conn))
+ lines.extend(get_cython_oids(conn))
+
+ update_file(fn, lines)
+
+
+def update_crdb_python_oids(conn: Connection) -> None:
+ fn = ROOT / "psycopg/psycopg/crdb.py"
+
+ lines = []
+ lines.extend(get_version_comment(conn))
+ lines.extend(get_py_types(conn))
+
+ update_file(fn, lines)
+ sp.check_call(["black", "-q", fn])
+
-# Note: "record" is a pseudotype but still a useful one to have.
-# "pg_lsn" is a documented public type and useful in streaming replication
-# treat "char" (with quotes) separately.
-py_types_sql = """
-select
- 'TypeInfo('
- || array_to_string(array_remove(array[
- format('%L', typname),
- oid::text,
- typarray::text,
- case when oid::regtype::text != typname
- then format('regtype=%L', oid::regtype)
- end,
- case when typdelim != ','
- then format('delimiter=%L', typdelim)
- end
- ], null), ',')
- || '),'
+def get_version_comment(conn: Connection) -> List[str]:
+ if conn.info.vendor == "PostgreSQL":
+ # Assume PG > 10
+ num = conn.info.server_version
+ version = f"{num // 10000}.{num % 100}"
+ elif conn.info.vendor == "CockroachDB":
+ assert isinstance(conn, CrdbConnection)
+ num = conn.info.crdb_version
+ version = f"{num // 10000}.{num % 10000 // 100}.{num % 100}"
+ else:
+ raise NotImplementedError(f"unexpected vendor: {conn.info.vendor}")
+ return ["", f" # Generated from {conn.info.vendor} {version}", ""]
+
+
+def get_py_types(conn: Connection) -> List[str]:
+ # Note: "record" is a pseudotype but still a useful one to have.
+ # "pg_lsn" is a documented public type and useful in streaming replication
+ lines = []
+ for (typname, oid, typarray, regtype, typdelim) in conn.execute(
+ """
+select typname, oid, typarray, typname::regtype::text as regtype, typdelim
from pg_type t
where
oid < 10000
and (typname !~ '^(_|pg_)' or typname = 'pg_lsn')
order by typname
"""
-
-py_ranges_sql = """
-select
- format('RangeInfo(%L, %s, %s, subtype_oid=%s),',
- typname, oid, typarray, rngsubtype)
+ ):
+ # Weird legacy type in postgres catalog
+ if typname == "char":
+ typname = regtype = '"char"'
+
+ # https://github.com/cockroachdb/cockroach/issues/81645
+ if typname == "int4" and conn.info.vendor == "CockroachDB":
+ regtype = typname
+
+ params = [f"{typname!r}, {oid}, {typarray}"]
+ if regtype != typname:
+ params.append(f"regtype={regtype!r}")
+ if typdelim != ",":
+ params.append(f"delimiter={typdelim!r}")
+ lines.append(f"TypeInfo({','.join(params)}),")
+
+ return lines
+
+
+def get_py_ranges(conn: Connection) -> List[str]:
+ lines = []
+ for (typname, oid, typarray, rngsubtype) in conn.execute(
+ """
+select typname, oid, typarray, rngsubtype
from
pg_type t
join pg_range r on t.oid = rngtypid
where
oid < 10000
and typtype = 'r'
- and (typname !~ '^(_|pg_)' or typname = 'pg_lsn')
order by typname
"""
+ ):
+ params = [f"{typname!r}, {oid}, {typarray}, subtype_oid={rngsubtype}"]
+ lines.append(f"RangeInfo({','.join(params)}),")
+
+ return lines
-py_multiranges_sql = """
-select
- format('MultirangeInfo(%L, %s, %s, range_oid=%s, subtype_oid=%s),',
- typname, oid, typarray, rngtypid, rngsubtype)
+
+def get_py_multiranges(conn: Connection) -> List[str]:
+ lines = []
+ for (typname, oid, typarray, rngtypid, rngsubtype) in conn.execute(
+ """
+select typname, oid, typarray, rngtypid, rngsubtype
from
pg_type t
join pg_range r on t.oid = rngmultitypid
where
oid < 10000
and typtype = 'm'
- and (typname !~ '^(_|pg_)' or typname = 'pg_lsn')
order by typname
"""
+ ):
+ params = [
+ f"{typname!r}, {oid}, {typarray},"
+ f" range_oid={rngtypid}, subtype_oid={rngsubtype}"
+ ]
+ lines.append(f"MultirangeInfo({','.join(params)}),")
+
+ return lines
-cython_oids_sql = """
-select format('%s_OID = %s', upper(typname), oid)
+
+def get_cython_oids(conn: Connection) -> List[str]:
+ lines = []
+ for (typname, oid) in conn.execute(
+ """
+select typname, oid
from pg_type
where
oid < 10000
and (typname !~ '^(_|pg_)' or typname = 'pg_lsn')
order by typname
"""
+ ):
+ const_name = typname.upper() + "_OID"
+ lines.append(f" {const_name} = {oid}")
-
-def update_python_oids() -> None:
- queries = [version_sql, py_types_sql, py_ranges_sql, py_multiranges_sql]
- fn = ROOT / "psycopg/psycopg/postgres.py"
- update_file(fn, queries)
- sp.check_call(["black", "-q", fn])
-
-
-def update_cython_oids() -> None:
- queries = [version_sql, cython_oids_sql]
- fn = ROOT / "psycopg_c/psycopg_c/_psycopg/oids.pxd"
- update_file(fn, queries)
+ return lines
-def update_file(fn: Path, queries: List[str]) -> None:
- with fn.open("rb") as f:
+def update_file(fn: Path, new: List[str]) -> None:
+ with fn.open("r") as f:
lines = f.read().splitlines()
-
- new = []
- for query in queries:
- out = sp.run(["psql", "-AXqt", "-c", query], stdout=sp.PIPE, check=True)
- new.extend(out.stdout.splitlines())
-
- new = [b" " * 4 + line if line else b"" for line in new] # indent
istart, iend = [
i
for i, line in enumerate(lines)
- if re.match(rb"\s*#\s*autogenerated:\s+(start|end)", line)
+ if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line)
]
lines[istart + 1 : iend] = new
- with fn.open("wb") as f:
- f.write(b"\n".join(lines))
- f.write(b"\n")
+ with fn.open("w") as f:
+ f.write("\n".join(lines))
+ f.write("\n")
+
+
+def parse_cmdline() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("dsn", help="where to connect to")
+ opt = parser.parse_args()
+ return opt
if __name__ == "__main__":
- update_python_oids()
- update_cython_oids()
+ main()