From: Daniele Varrazzo Date: Wed, 25 May 2022 08:48:43 +0000 (+0200) Subject: fix(crdb): allow None roundtrip X-Git-Tag: 3.1~49^2~42 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=66daacd745e9694e3e6af6f27a830b3795b405ca;p=thirdparty%2Fpsycopg.git fix(crdb): allow None roundtrip Specify the TEXT oid for None, otherwise dumping with no context will fail. --- diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py index ba2de6417..93a4435d9 100644 --- a/psycopg/psycopg/_transform.py +++ b/psycopg/psycopg/_transform.py @@ -51,11 +51,13 @@ class Transformer(AdaptContext): _oid_dumpers _oid_types _row_dumpers _row_loaders """.split() - _adapters: "AdaptersMap" - _pgresult: Optional["PGresult"] types: Optional[Tuple[int, ...]] formats: Optional[List[pq.Format]] + _adapters: "AdaptersMap" + _pgresult: Optional["PGresult"] + _none_oid: int + def __init__(self, context: Optional[AdaptContext] = None): self._pgresult = self.types = self.formats = None @@ -251,7 +253,7 @@ class Transformer(AdaptContext): dumper = cache[key1] = dumper.upgrade(obj, format) return dumper - def _get_none_oid(self): + def _get_none_oid(self) -> int: try: return self._none_oid except AttributeError: diff --git a/psycopg/psycopg/crdb.py b/psycopg/psycopg/crdb.py index 62743fcda..2b0e60cd9 100644 --- a/psycopg/psycopg/crdb.py +++ b/psycopg/psycopg/crdb.py @@ -18,6 +18,7 @@ from .connection import Connection from ._adapters_map import AdaptersMap from .connection_async import AsyncConnection from .types.enum import EnumDumper, EnumBinaryDumper +from .types.none import NoneDumper if TYPE_CHECKING: from .pq.abc import PGconn @@ -184,18 +185,21 @@ class CrdbEnumBinaryDumper(EnumBinaryDumper): oid = TEXT_OID +class CrdbNoneDumper(NoneDumper): + oid = TEXT_OID + + def register_postgres_adapters(context: AdaptContext) -> None: # Same adapters used by PostgreSQL, or a good starting point for customization from .types import array, bool, composite, datetime - from .types import json, none, numeric, string, uuid + from .types import json, numeric, string, uuid array.register_default_adapters(context) bool.register_default_adapters(context) composite.register_default_adapters(context) datetime.register_default_adapters(context) json.register_default_adapters(context) - none.register_default_adapters(context) numeric.register_default_adapters(context) string.register_default_adapters(context) uuid.register_default_adapters(context) @@ -211,6 +215,7 @@ def register_crdb_adapters(context: AdaptContext) -> None: register_crdb_string_adapters(context) register_crdb_json_adapters(context) register_crdb_net_adapters(context) + register_crdb_none_adapters(context) array.register_all_arrays(adapters) @@ -257,6 +262,10 @@ def register_crdb_net_adapters(context: AdaptContext) -> None: context.adapters.register_loader("inet", net.InetBinaryLoader) +def register_crdb_none_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(type(None), CrdbNoneDumper) + + for t in [ TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb. TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8 diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 8ffa30301..29ee2ebbd 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -210,28 +210,25 @@ def test_fetchone(conn): def test_binary_cursor_execute(conn): - unk = "foo" if is_crdb(conn) else None cur = conn.cursor(binary=True) - cur.execute("select %s, %s", [1, unk]) - assert cur.fetchone() == (1, unk) + cur.execute("select %s, %s", [1, None]) + assert cur.fetchone() == (1, None) assert cur.pgresult.fformat(0) == 1 assert cur.pgresult.get_value(0, 0) == b"\x00\x01" def test_execute_binary(conn): cur = conn.cursor() - unk = "foo" if is_crdb(conn) else None - cur.execute("select %s, %s", [1, unk], binary=True) - assert cur.fetchone() == (1, unk) + cur.execute("select %s, %s", [1, None], binary=True) + assert cur.fetchone() == (1, None) assert cur.pgresult.fformat(0) == 1 assert cur.pgresult.get_value(0, 0) == b"\x00\x01" def test_binary_cursor_text_override(conn): cur = conn.cursor(binary=True) - unk = "foo" if is_crdb(conn) else None - cur.execute("select %s, %s", [1, unk], binary=False) - assert cur.fetchone() == (1, unk) + cur.execute("select %s, %s", [1, None], binary=False) + assert cur.fetchone() == (1, None) assert cur.pgresult.fformat(0) == 0 assert cur.pgresult.get_value(0, 0) == b"1" diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 8b2b618f0..2161a30d7 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -11,7 +11,7 @@ from psycopg.adapt import PyFormat from .utils import gc_collect from .test_cursor import my_row_factory from .test_cursor import execmany, _execmany # noqa: F401 -from .fix_crdb import is_crdb, crdb_encoding +from .fix_crdb import crdb_encoding execmany = execmany # avoid F811 underneath pytestmark = pytest.mark.asyncio @@ -212,28 +212,25 @@ async def test_fetchone(aconn): async def test_binary_cursor_execute(aconn): - unk = "foo" if is_crdb(aconn) else None cur = aconn.cursor(binary=True) - await cur.execute("select %s, %s", [1, unk]) - assert (await cur.fetchone()) == (1, unk) + await cur.execute("select %s, %s", [1, None]) + assert (await cur.fetchone()) == (1, None) assert cur.pgresult.fformat(0) == 1 assert cur.pgresult.get_value(0, 0) == b"\x00\x01" async def test_execute_binary(aconn): - unk = "foo" if is_crdb(aconn) else None cur = aconn.cursor() - await cur.execute("select %s, %s", [1, unk], binary=True) - assert (await cur.fetchone()) == (1, unk) + await cur.execute("select %s, %s", [1, None], binary=True) + assert (await cur.fetchone()) == (1, None) assert cur.pgresult.fformat(0) == 1 assert cur.pgresult.get_value(0, 0) == b"\x00\x01" async def test_binary_cursor_text_override(aconn): - unk = "foo" if is_crdb(aconn) else None cur = aconn.cursor(binary=True) - await cur.execute("select %s, %s", [1, unk], binary=False) - assert (await cur.fetchone()) == (1, unk) + await cur.execute("select %s, %s", [1, None], binary=False) + assert (await cur.fetchone()) == (1, None) assert cur.pgresult.fformat(0) == 0 assert cur.pgresult.get_value(0, 0) == b"1" diff --git a/tests/test_prepared.py b/tests/test_prepared.py index 716496eab..dea238ca4 100644 --- a/tests/test_prepared.py +++ b/tests/test_prepared.py @@ -10,8 +10,6 @@ import pytest import psycopg from psycopg.rows import namedtuple_row -from .fix_crdb import is_crdb - @pytest.mark.parametrize("value", [None, 0, 3]) def test_prepare_threshold_init(dsn, value): @@ -181,9 +179,7 @@ def test_evict_lru_deallocate(conn): def test_different_types(conn): conn.prepare_threshold = 0 - # CRDB can't roundtrip None - unk = "foo" if is_crdb(conn) else None - conn.execute("select %s", [unk]) + conn.execute("select %s", [None]) conn.execute("select %s", [dt.date(2000, 1, 1)]) conn.execute("select %s", [42]) conn.execute("select %s", [41]) diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py index 2983f2188..a40a169e4 100644 --- a/tests/test_prepared_async.py +++ b/tests/test_prepared_async.py @@ -10,8 +10,6 @@ import pytest import psycopg from psycopg.rows import namedtuple_row -from .fix_crdb import is_crdb - pytestmark = pytest.mark.asyncio @@ -175,9 +173,7 @@ async def test_evict_lru_deallocate(aconn): async def test_different_types(aconn): aconn.prepare_threshold = 0 - # CRDB can't roundtrip None - unk = "foo" if is_crdb(aconn) else None - await aconn.execute("select %s", [unk]) + await aconn.execute("select %s", [None]) await aconn.execute("select %s", [dt.date(2000, 1, 1)]) await aconn.execute("select %s", [42]) await aconn.execute("select %s", [41])