From: Daniele Varrazzo
Date: Mon, 23 May 2022 00:37:58 +0000 (+0200)
Subject: feat(crdb): add CrdbConnection class
X-Git-Tag: 3.1~49^2~54
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f29628d87d5636eb0c99760deb0624ed2a017171;p=thirdparty%2Fpsycopg.git

feat(crdb): add CrdbConnection class

Drop the automatic detection of CockroachDB: it seems dangerous to change
the program's type configuration based on an external factor such as the
database it happens to be connected to.

Generate and use a CRDB-specific map of type oids, on which to build the
adapters map.
---
diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py
index 7937c4552..abd7149f4 100644
--- a/psycopg/psycopg/connection.py
+++ b/psycopg/psycopg/connection.py
@@ -114,8 +114,6 @@ class BaseConnection(Generic[Row]):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    _info_class: Type[ConnectionInfo] = ConnectionInfo
-
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn
         self._autocommit = False
@@ -285,7 +283,7 @@ class BaseConnection(Generic[Row]):
     @property
     def info(self) -> ConnectionInfo:
         """A `ConnectionInfo` attribute to inspect connection properties."""
-        return self._info_class(self.pgconn)
+        return ConnectionInfo(self.pgconn)
 
     @property
     def adapters(self) -> AdaptersMap:
@@ -740,15 +738,7 @@ class Connection(BaseConnection[Row]):
             rv.cursor_factory = cursor_factory
         if context:
             rv._adapters = AdaptersMap(context.adapters)
-
         rv.prepare_threshold = prepare_threshold
-
-        # TODOCRDB find the right place for this operation
-        if rv.pgconn.parameter_status(b"crdb_version"):
-            from .crdb import customize_crdb_connection
-
-            customize_crdb_connection(rv)
-
         return rv
 
     def __enter__(self: _Self) -> _Self:
diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py
index 334e60dde..1545667b7 100644
--- a/psycopg/psycopg/connection_async.py
+++ b/psycopg/psycopg/connection_async.py
@@ -136,13 +136,6 @@ class AsyncConnection(BaseConnection[Row]):
         if context:
             rv._adapters = AdaptersMap(context.adapters)
         rv.prepare_threshold = prepare_threshold
-
-        # TODOCRDB find the right place for this operation
-        if rv.pgconn.parameter_status(b"crdb_version"):
-            from .crdb import customize_crdb_connection
-
-            customize_crdb_connection(rv)
-
         return rv
 
     async def __aenter__(self: _Self) -> _Self:
diff --git a/psycopg/psycopg/crdb.py b/psycopg/psycopg/crdb.py
index 64f74c618..051e74075 100644
--- a/psycopg/psycopg/crdb.py
+++ b/psycopg/psycopg/crdb.py
@@ -7,19 +7,63 @@ Types configuration specific for CockroachDB.
 import re
 from enum import Enum
 from typing import Any, Optional, Union, TYPE_CHECKING
+from ._typeinfo import TypeInfo, TypesRegistry
 from .
import errors as e from .abc import AdaptContext -from .postgres import adapters as pg_adapters, TEXT_OID +from .rows import Row +from .postgres import TEXT_OID from .conninfo import ConnectionInfo +from .connection import Connection from ._adapters_map import AdaptersMap +from .connection_async import AsyncConnection from .types.enum import EnumDumper, EnumBinaryDumper -adapters = AdaptersMap(pg_adapters) - if TYPE_CHECKING: - from .connection import Connection - from .connection_async import AsyncConnection + from .pq.abc import PGconn + +types = TypesRegistry() + +# Global adapter maps with PostgreSQL types configuration +adapters = AdaptersMap(types=types) + + +class _CrdbConnectionMixin: + + _adapters: Optional[AdaptersMap] + pgconn: "PGconn" + + @classmethod + def is_crdb( + cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"] + ) -> bool: + """ + Return True if the server connected to ``conn`` is CockroachDB. + """ + if isinstance(conn, (Connection, AsyncConnection)): + conn = conn.pgconn + + return bool(conn.parameter_status(b"crdb_version")) + + @property + def adapters(self) -> AdaptersMap: + if not self._adapters: + # By default, use CockroachDB adapters map + self._adapters = AdaptersMap(adapters) + + return self._adapters + + @property + def info(self) -> "CrdbConnectionInfo": + return CrdbConnectionInfo(self.pgconn) + + +class CrdbConnection(_CrdbConnectionMixin, Connection[Row]): + pass + + +class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]): + pass class CrdbConnectionInfo(ConnectionInfo): @@ -65,33 +109,120 @@ class CrdbEnumBinaryDumper(EnumBinaryDumper): oid = TEXT_OID +def register_postgres_adapters(context: AdaptContext) -> None: + # Same adapters used by PostgreSQL, or a good starting point for customization + + from .types import array, bool, datetime, json, none, numeric, string, uuid + + array.register_default_adapters(context) + bool.register_default_adapters(context) + datetime.register_default_adapters(context) + json.register_default_adapters(context) + none.register_default_adapters(context) + numeric.register_default_adapters(context) + string.register_default_adapters(context) + uuid.register_default_adapters(context) + + def register_crdb_adapters(context: AdaptContext) -> None: - from .types import string, json + from .types import array + + register_postgres_adapters(context) - adapters = context.adapters + # String must come after enum to map text oid -> string dumper + register_crdb_enum_adapters(context) + register_crdb_string_adapters(context) + register_crdb_json_adapters(context) + register_crdb_net_adapters(context) + + array.register_all_arrays(adapters) + + +def register_crdb_string_adapters(context: AdaptContext) -> None: + from .types import string # Dump strings with text oid instead of unknown. # Unlike PostgreSQL, CRDB seems able to cast text to most types. 
- adapters.register_dumper(str, string.StrDumper) - adapters.register_dumper(Enum, CrdbEnumBinaryDumper) - adapters.register_dumper(Enum, CrdbEnumDumper) + context.adapters.register_dumper(str, string.StrDumper) + context.adapters.register_dumper(str, string.StrBinaryDumper) - # CRDB doesn't have json/jsonb: both dump as the jsonb oid - adapters.register_dumper(json.Json, json.JsonbBinaryDumper) - adapters.register_dumper(json.Json, json.JsonbDumper) +def register_crdb_enum_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper) + context.adapters.register_dumper(Enum, CrdbEnumDumper) -register_crdb_adapters(adapters) +def register_crdb_json_adapters(context: AdaptContext) -> None: + from .types import json + + # CRDB doesn't have json/jsonb: both dump as the jsonb oid + context.adapters.register_dumper(json.Json, json.JsonbBinaryDumper) + context.adapters.register_dumper(json.Json, json.JsonbDumper) + + +def register_crdb_net_adapters(context: AdaptContext) -> None: + from psycopg.types import net + + context.adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper) + context.adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper) + context.adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper) + context.adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper) + context.adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper) + context.adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper) + context.adapters.register_dumper( + "ipaddress.IPv4Interface", net.InterfaceBinaryDumper + ) + context.adapters.register_dumper( + "ipaddress.IPv6Interface", net.InterfaceBinaryDumper + ) + context.adapters.register_dumper(None, net.InetBinaryDumper) + context.adapters.register_loader("inet", net.InetLoader) + context.adapters.register_loader("inet", net.InetBinaryLoader) + + +for t in [ + TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb. 
+ TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8 + TypeInfo('"char"', 18, 1002), # special case, not generated + # autogenerated: start + # Generated from CockroachDB 22.1.0 + TypeInfo("bit", 1560, 1561), + TypeInfo("bool", 16, 1000, regtype="boolean"), + TypeInfo("bpchar", 1042, 1014, regtype="character"), + TypeInfo("bytea", 17, 1001), + TypeInfo("date", 1082, 1182), + TypeInfo("float4", 700, 1021, regtype="real"), + TypeInfo("float8", 701, 1022, regtype="'double precision'"), + TypeInfo("inet", 869, 1041), + TypeInfo("int2", 21, 1005, regtype="smallint"), + TypeInfo("int2vector", 22, 1006), + TypeInfo("int4", 23, 1007), + TypeInfo("int8", 20, 1016, regtype="bigint"), + TypeInfo("interval", 1186, 1187), + TypeInfo("jsonb", 3802, 3807), + TypeInfo("name", 19, 1003), + TypeInfo("numeric", 1700, 1231), + TypeInfo("oid", 26, 1028), + TypeInfo("oidvector", 30, 1013), + TypeInfo("record", 2249, 2287), + TypeInfo("regclass", 2205, 2210), + TypeInfo("regnamespace", 4089, 4090), + TypeInfo("regproc", 24, 1008), + TypeInfo("regprocedure", 2202, 2207), + TypeInfo("regrole", 4096, 4097), + TypeInfo("regtype", 2206, 2211), + TypeInfo("text", 25, 1009), + TypeInfo("time", 1083, 1183, regtype="'time without time zone'"), + TypeInfo("timestamp", 1114, 1115, regtype="'timestamp without time zone'"), + TypeInfo("timestamptz", 1184, 1185, regtype="'timestamp with time zone'"), + TypeInfo("timetz", 1266, 1270, regtype="'time with time zone'"), + TypeInfo("unknown", 705, 0), + TypeInfo("uuid", 2950, 2951), + TypeInfo("varbit", 1562, 1563, regtype="'bit varying'"), + TypeInfo("varchar", 1043, 1015, regtype="'character varying'"), + # autogenerated: end +]: + types.add(t) -def customize_crdb_connection( - conn: "Union[Connection[Any], AsyncConnection[Any]]", -) -> None: - conn._info_class = CrdbConnectionInfo - # TODOCRDB: what if someone is passing context? they will have - # customised the postgres adapters, so those changes wouldn't apply - # to crdb (e.g. the Django backend in preparation). 
- if conn._adapters is None: - # Not customized by connect() - conn._adapters = AdaptersMap(adapters) +register_crdb_adapters(adapters) diff --git a/tests/crdb/test_adapt.py b/tests/crdb/test_adapt.py index 2b64e3ff6..ce5bacf9f 100644 --- a/tests/crdb/test_adapt.py +++ b/tests/crdb/test_adapt.py @@ -2,7 +2,8 @@ from copy import deepcopy import pytest -import psycopg.crdb +from psycopg.crdb import adapters, CrdbConnection + from psycopg.adapt import PyFormat, Transformer from psycopg.types.array import ListDumper from psycopg.postgres import types as builtins @@ -40,8 +41,6 @@ def test_str_list_dumper_text(conn): @pytest.fixture def crdb_adapters(): """Restore the crdb adapters after a test has changed them.""" - from psycopg.crdb import adapters - dumpers = deepcopy(adapters._dumpers) dumpers_by_oid = deepcopy(adapters._dumpers_by_oid) loaders = deepcopy(adapters._loaders) @@ -58,9 +57,9 @@ def crdb_adapters(): def test_dump_global_ctx(dsn, crdb_adapters, pgconn): - psycopg.crdb.adapters.register_dumper(MyStr, make_bin_dumper("gb")) - psycopg.crdb.adapters.register_dumper(MyStr, make_dumper("gt")) - with psycopg.connect(dsn) as conn: + adapters.register_dumper(MyStr, make_bin_dumper("gb")) + adapters.register_dumper(MyStr, make_dumper("gt")) + with CrdbConnection.connect(dsn) as conn: cur = conn.execute("select %s", [MyStr("hello")]) assert cur.fetchone() == ("hellogt",) cur = conn.execute("select %b", [MyStr("hello")]) @@ -70,9 +69,9 @@ def test_dump_global_ctx(dsn, crdb_adapters, pgconn): def test_load_global_ctx(dsn, crdb_adapters): - psycopg.crdb.adapters.register_loader("text", make_loader("gt")) - psycopg.crdb.adapters.register_loader("text", make_bin_loader("gb")) - with psycopg.connect(dsn) as conn: + adapters.register_loader("text", make_loader("gt")) + adapters.register_loader("text", make_bin_loader("gb")) + with CrdbConnection.connect(dsn) as conn: cur = conn.cursor(binary=False).execute("select 'hello'::text") assert cur.fetchone() == ("hellogt",) cur = conn.cursor(binary=True).execute("select 'hello'::text") diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py new file mode 100644 index 000000000..9c169e3a3 --- /dev/null +++ b/tests/crdb/test_connection.py @@ -0,0 +1,10 @@ +from psycopg.crdb import CrdbConnection + +import pytest + +pytestmark = pytest.mark.crdb + + +def test_is_crdb(conn): + assert CrdbConnection.is_crdb(conn) + assert CrdbConnection.is_crdb(conn.pgconn) diff --git a/tests/crdb/test_no_crdb.py b/tests/crdb/test_no_crdb.py new file mode 100644 index 000000000..ac1bc1897 --- /dev/null +++ b/tests/crdb/test_no_crdb.py @@ -0,0 +1,10 @@ +from psycopg.crdb import CrdbConnection + +import pytest + +pytestmark = pytest.mark.crdb("skip") + + +def test_is_crdb(conn): + assert not CrdbConnection.is_crdb(conn) + assert not CrdbConnection.is_crdb(conn.pgconn) diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py index 01619c5f3..15243fa92 100644 --- a/tests/fix_crdb.py +++ b/tests/fix_crdb.py @@ -1,6 +1,7 @@ import pytest from .utils import check_version +from psycopg.crdb import CrdbConnection def pytest_configure(config): @@ -57,12 +58,7 @@ def check_crdb_version(got, func): # Utility functions which can be imported in the test suite - -def is_crdb(conn): - if hasattr(conn, "pgconn"): - conn = conn.pgconn - - return bool(conn.parameter_status(b"crdb_version")) +is_crdb = CrdbConnection.is_crdb def crdb_skip_message(reason): diff --git a/tests/fix_db.py b/tests/fix_db.py index 69ec20b17..d5f9e4a1a 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ 
-6,6 +6,8 @@ from typing import List, Optional import psycopg from psycopg import pq +from psycopg import sql +from psycopg.crdb import CrdbConnection from .utils import check_libpq_version, check_server_version @@ -143,8 +145,6 @@ def pgconn(dsn, request, tracefile): """Return a PGconn connection open to `--test-dsn`.""" check_connection_version(request.function) - from psycopg import pq - conn = pq.PGconn.connect(dsn.encode()) if conn.status != pq.ConnStatus.OK: pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}") @@ -160,9 +160,11 @@ def conn(dsn, request, tracefile): """Return a `Connection` connected to the ``--test-dsn`` database.""" check_connection_version(request.function) - from psycopg import Connection + cls = psycopg.Connection + if crdb_version: + cls = CrdbConnection - conn = Connection.connect(dsn) + conn = cls.connect(dsn) with maybe_trace(conn.pgconn, tracefile, request.function): yield conn conn.close() @@ -186,9 +188,13 @@ async def aconn(dsn, request, tracefile): """Return an `AsyncConnection` connected to the ``--test-dsn`` database.""" check_connection_version(request.function) - from psycopg import AsyncConnection + cls = psycopg.AsyncConnection + if crdb_version: + from psycopg.crdb import AsyncCrdbConnection + + cls = AsyncCrdbConnection - conn = await AsyncConnection.connect(dsn) + conn = await cls.connect(dsn) with maybe_trace(conn.pgconn, tracefile, request.function): yield conn await conn.close() @@ -212,9 +218,7 @@ def svcconn(session_dsn): """ Return a session `Connection` connected to the ``--test-dsn`` database. """ - from psycopg import Connection - - conn = Connection.connect(session_dsn, autocommit=True) + conn = psycopg.Connection.connect(session_dsn, autocommit=True) yield conn conn.close() @@ -233,8 +237,6 @@ def acommands(aconn, monkeypatch): def patch_exec(conn, monkeypatch): """Helper to implement the commands fixture both sync and async.""" - from psycopg import sql - _orig_exec_command = conn._exec_command L = ListPopAll() @@ -285,12 +287,10 @@ def check_connection_version(function): @pytest.fixture def hstore(svcconn): - from psycopg import Error - try: with svcconn.transaction(): svcconn.execute("create extension if not exists hstore") - except Error as e: + except psycopg.Error as e: pytest.skip(str(e)) @@ -309,13 +309,14 @@ def warm_up_database(dsn: str, __first_connection: List[bool] = [True]) -> None: global pg_version, crdb_version - import psycopg - with psycopg.connect(dsn, connect_timeout=10) as conn: conn.execute("select 1") pg_version = conn.info.server_version - if conn.info.vendor == "CockroachDB": - crdb_version = conn.info.crdb_version # type: ignore - else: - crdb_version = None + + crdb_version = None + param = conn.info.parameter_status("crdb_version") + if param: + from psycopg.crdb import CrdbConnectionInfo + + crdb_version = CrdbConnectionInfo.parse_crdb_version(param) diff --git a/tools/update_oids.py b/tools/update_oids.py index 802af1f81..17be3ac24 100755 --- a/tools/update_oids.py +++ b/tools/update_oids.py @@ -2,44 +2,93 @@ """ Update the maps of builtin types and names. -You can update this file by executing it, using the PG* env var to connect +This script updates some of the files in psycopg source code with data read +from a database catalog. 
""" import re +import argparse import subprocess as sp from typing import List from pathlib import Path +import psycopg +from psycopg.rows import TupleRow +from psycopg.crdb import CrdbConnection +from psycopg._compat import TypeAlias + +Connection: TypeAlias = psycopg.Connection[TupleRow] + ROOT = Path(__file__).parent.parent -version_sql = """ -select format($$ -# Generated from PostgreSQL %s.%s -$$, - setting::int / 10000, setting::int % 100) -- assume PG >= 10 - from pg_settings - where name = 'server_version_num' -""" +def main() -> None: + opt = parse_cmdline() + conn = psycopg.connect(opt.dsn, autocommit=True) + + if CrdbConnection.is_crdb(conn): + conn = CrdbConnection.connect(opt.dsn, autocommit=True) + update_crdb_python_oids(conn) + else: + update_python_oids(conn) + update_cython_oids(conn) + + +def update_python_oids(conn: Connection) -> None: + fn = ROOT / "psycopg/psycopg/postgres.py" + + lines = [] + lines.extend(get_version_comment(conn)) + lines.extend(get_py_types(conn)) + lines.extend(get_py_ranges(conn)) + lines.extend(get_py_multiranges(conn)) + + update_file(fn, lines) + sp.check_call(["black", "-q", fn]) + + +def update_cython_oids(conn: Connection) -> None: + fn = ROOT / "psycopg_c/psycopg_c/_psycopg/oids.pxd" + + lines = [] + lines.extend(get_version_comment(conn)) + lines.extend(get_cython_oids(conn)) + + update_file(fn, lines) + + +def update_crdb_python_oids(conn: Connection) -> None: + fn = ROOT / "psycopg/psycopg/crdb.py" + + lines = [] + lines.extend(get_version_comment(conn)) + lines.extend(get_py_types(conn)) + + update_file(fn, lines) + sp.check_call(["black", "-q", fn]) + -# Note: "record" is a pseudotype but still a useful one to have. -# "pg_lsn" is a documented public type and useful in streaming replication -# treat "char" (with quotes) separately. -py_types_sql = """ -select - 'TypeInfo(' - || array_to_string(array_remove(array[ - format('%L', typname), - oid::text, - typarray::text, - case when oid::regtype::text != typname - then format('regtype=%L', oid::regtype) - end, - case when typdelim != ',' - then format('delimiter=%L', typdelim) - end - ], null), ',') - || '),' +def get_version_comment(conn: Connection) -> List[str]: + if conn.info.vendor == "PostgreSQL": + # Assume PG > 10 + num = conn.info.server_version + version = f"{num // 10000}.{num % 100}" + elif conn.info.vendor == "CockroachDB": + assert isinstance(conn, CrdbConnection) + num = conn.info.crdb_version + version = f"{num // 10000}.{num % 10000 // 100}.{num % 100}" + else: + raise NotImplementedError(f"unexpected vendor: {conn.info.vendor}") + return ["", f" # Generated from {conn.info.vendor} {version}", ""] + + +def get_py_types(conn: Connection) -> List[str]: + # Note: "record" is a pseudotype but still a useful one to have. 
+ # "pg_lsn" is a documented public type and useful in streaming replication + lines = [] + for (typname, oid, typarray, regtype, typdelim) in conn.execute( + """ +select typname, oid, typarray, typname::regtype::text as regtype, typdelim from pg_type t where oid < 10000 @@ -48,37 +97,73 @@ where and (typname !~ '^(_|pg_)' or typname = 'pg_lsn') order by typname """ - -py_ranges_sql = """ -select - format('RangeInfo(%L, %s, %s, subtype_oid=%s),', - typname, oid, typarray, rngsubtype) + ): + # Weird legacy type in postgres catalog + if typname == "char": + typname = regtype = '"char"' + + # https://github.com/cockroachdb/cockroach/issues/81645 + if typname == "int4" and conn.info.vendor == "CockroachDB": + regtype = typname + + params = [f"{typname!r}, {oid}, {typarray}"] + if regtype != typname: + params.append(f"regtype={regtype!r}") + if typdelim != ",": + params.append(f"delimiter={typdelim!r}") + lines.append(f"TypeInfo({','.join(params)}),") + + return lines + + +def get_py_ranges(conn: Connection) -> List[str]: + lines = [] + for (typname, oid, typarray, rngsubtype) in conn.execute( + """ +select typname, oid, typarray, rngsubtype from pg_type t join pg_range r on t.oid = rngtypid where oid < 10000 and typtype = 'r' - and (typname !~ '^(_|pg_)' or typname = 'pg_lsn') order by typname """ + ): + params = [f"{typname!r}, {oid}, {typarray}, subtype_oid={rngsubtype}"] + lines.append(f"RangeInfo({','.join(params)}),") + + return lines -py_multiranges_sql = """ -select - format('MultirangeInfo(%L, %s, %s, range_oid=%s, subtype_oid=%s),', - typname, oid, typarray, rngtypid, rngsubtype) + +def get_py_multiranges(conn: Connection) -> List[str]: + lines = [] + for (typname, oid, typarray, rngtypid, rngsubtype) in conn.execute( + """ +select typname, oid, typarray, rngtypid, rngsubtype from pg_type t join pg_range r on t.oid = rngmultitypid where oid < 10000 and typtype = 'm' - and (typname !~ '^(_|pg_)' or typname = 'pg_lsn') order by typname """ + ): + params = [ + f"{typname!r}, {oid}, {typarray}," + f" range_oid={rngtypid}, subtype_oid={rngsubtype}" + ] + lines.append(f"MultirangeInfo({','.join(params)}),") + + return lines -cython_oids_sql = """ -select format('%s_OID = %s', upper(typname), oid) + +def get_cython_oids(conn: Connection) -> List[str]: + lines = [] + for (typname, oid) in conn.execute( + """ +select typname, oid from pg_type where oid < 10000 @@ -86,43 +171,34 @@ where and (typname !~ '^(_|pg_)' or typname = 'pg_lsn') order by typname """ + ): + const_name = typname.upper() + "_OID" + lines.append(f" {const_name} = {oid}") - -def update_python_oids() -> None: - queries = [version_sql, py_types_sql, py_ranges_sql, py_multiranges_sql] - fn = ROOT / "psycopg/psycopg/postgres.py" - update_file(fn, queries) - sp.check_call(["black", "-q", fn]) - - -def update_cython_oids() -> None: - queries = [version_sql, cython_oids_sql] - fn = ROOT / "psycopg_c/psycopg_c/_psycopg/oids.pxd" - update_file(fn, queries) + return lines -def update_file(fn: Path, queries: List[str]) -> None: - with fn.open("rb") as f: +def update_file(fn: Path, new: List[str]) -> None: + with fn.open("r") as f: lines = f.read().splitlines() - - new = [] - for query in queries: - out = sp.run(["psql", "-AXqt", "-c", query], stdout=sp.PIPE, check=True) - new.extend(out.stdout.splitlines()) - - new = [b" " * 4 + line if line else b"" for line in new] # indent istart, iend = [ i for i, line in enumerate(lines) - if re.match(rb"\s*#\s*autogenerated:\s+(start|end)", line) + if 
re.match(r"\s*#\s*autogenerated:\s+(start|end)", line) ] lines[istart + 1 : iend] = new - with fn.open("wb") as f: - f.write(b"\n".join(lines)) - f.write(b"\n") + with fn.open("w") as f: + f.write("\n".join(lines)) + f.write("\n") + + +def parse_cmdline() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("dsn", help="where to connect to") + opt = parser.parse_args() + return opt if __name__ == "__main__": - update_python_oids() - update_cython_oids() + main()
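For reference, a minimal usage sketch of the classes added above, assuming
psycopg built from this branch and a reachable CockroachDB server; the DSN
below is a placeholder and error handling is omitted:

    import psycopg
    from psycopg.crdb import CrdbConnection

    dsn = "host=localhost port=26257 dbname=defaultdb user=root"  # placeholder

    # is_crdb() accepts a Connection, an AsyncConnection or a raw PGconn.
    with psycopg.connect(dsn) as conn:
        print(CrdbConnection.is_crdb(conn))  # True when talking to CockroachDB

    # Connecting through CrdbConnection picks up the crdb-specific adapters
    # map and exposes a CrdbConnectionInfo as conn.info.
    with CrdbConnection.connect(dsn) as conn:
        cur = conn.execute("select 1")
        print(cur.fetchone())

A plain psycopg.connect() keeps the stock PostgreSQL configuration; opting
into the CockroachDB behaviour is now an explicit choice of connection class,
which is the point of dropping the automatic detection.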
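The oid table above feeds a dedicated TypesRegistry rather than the
PostgreSQL one. A rough sketch of inspecting it, assuming the usual
TypesRegistry lookup methods (get()):

    from psycopg.crdb import types

    # "jsonb" is listed above with oid 3802 and array oid 3807; the "json"
    # entry maps to the same oids (see the "Alias json -> jsonb" comment).
    info = types.get("jsonb")
    if info is not None:
        print(info.name, info.oid, info.array_oid)

The autogenerated block is refreshed by tools/update_oids.py, which now takes
the target DSN as its only command-line argument and switches to the
CockroachDB-specific update path when CrdbConnection.is_crdb() is true.
Dumpers and loaders registered on the module-level psycopg.crdb.adapters map
(as in tests/crdb/test_adapt.py) apply to every connection opened through
CrdbConnection, mirroring what psycopg.adapters does for Connection.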