]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(crdb): allow None roundtrip
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 25 May 2022 08:48:43 +0000 (10:48 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
Specify the TEXT oid for None, otherwise dumping with no context will
fail.

psycopg/psycopg/_transform.py
psycopg/psycopg/crdb.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/test_prepared.py
tests/test_prepared_async.py

index ba2de6417a44da8f3ed08b404d1fc262048e54f0..93a4435d93c1455df096bf682e8c423713d28bb4 100644 (file)
@@ -51,11 +51,13 @@ class Transformer(AdaptContext):
         _oid_dumpers _oid_types _row_dumpers _row_loaders
         """.split()
 
-    _adapters: "AdaptersMap"
-    _pgresult: Optional["PGresult"]
     types: Optional[Tuple[int, ...]]
     formats: Optional[List[pq.Format]]
 
+    _adapters: "AdaptersMap"
+    _pgresult: Optional["PGresult"]
+    _none_oid: int
+
     def __init__(self, context: Optional[AdaptContext] = None):
         self._pgresult = self.types = self.formats = None
 
@@ -251,7 +253,7 @@ class Transformer(AdaptContext):
             dumper = cache[key1] = dumper.upgrade(obj, format)
             return dumper
 
-    def _get_none_oid(self):
+    def _get_none_oid(self) -> int:
         try:
             return self._none_oid
         except AttributeError:
index 62743fcdaaebb7aa14cbbbd23e95dbdc79112dff..2b0e60cd91ee264b157ac592426c37c69e91ab8b 100644 (file)
@@ -18,6 +18,7 @@ from .connection import Connection
 from ._adapters_map import AdaptersMap
 from .connection_async import AsyncConnection
 from .types.enum import EnumDumper, EnumBinaryDumper
+from .types.none import NoneDumper
 
 if TYPE_CHECKING:
     from .pq.abc import PGconn
@@ -184,18 +185,21 @@ class CrdbEnumBinaryDumper(EnumBinaryDumper):
     oid = TEXT_OID
 
 
+class CrdbNoneDumper(NoneDumper):
+    oid = TEXT_OID
+
+
 def register_postgres_adapters(context: AdaptContext) -> None:
     # Same adapters used by PostgreSQL, or a good starting point for customization
 
     from .types import array, bool, composite, datetime
-    from .types import json, none, numeric, string, uuid
+    from .types import json, numeric, string, uuid
 
     array.register_default_adapters(context)
     bool.register_default_adapters(context)
     composite.register_default_adapters(context)
     datetime.register_default_adapters(context)
     json.register_default_adapters(context)
-    none.register_default_adapters(context)
     numeric.register_default_adapters(context)
     string.register_default_adapters(context)
     uuid.register_default_adapters(context)
@@ -211,6 +215,7 @@ def register_crdb_adapters(context: AdaptContext) -> None:
     register_crdb_string_adapters(context)
     register_crdb_json_adapters(context)
     register_crdb_net_adapters(context)
+    register_crdb_none_adapters(context)
 
     array.register_all_arrays(adapters)
 
@@ -257,6 +262,10 @@ def register_crdb_net_adapters(context: AdaptContext) -> None:
     context.adapters.register_loader("inet", net.InetBinaryLoader)
 
 
+def register_crdb_none_adapters(context: AdaptContext) -> None:
+    context.adapters.register_dumper(type(None), CrdbNoneDumper)
+
+
 for t in [
     TypeInfo("json", 3802, 3807, regtype="jsonb"),  # Alias json -> jsonb.
     TypeInfo("int8", 20, 1016, regtype="integer"),  # Alias integer -> int8
index 8ffa303011b31797811500d86e99edce40b3cb2e..29ee2ebbd6486fa404d58dea66e151f3f9b6b3b3 100644 (file)
@@ -210,28 +210,25 @@ def test_fetchone(conn):
 
 
 def test_binary_cursor_execute(conn):
-    unk = "foo" if is_crdb(conn) else None
     cur = conn.cursor(binary=True)
-    cur.execute("select %s, %s", [1, unk])
-    assert cur.fetchone() == (1, unk)
+    cur.execute("select %s, %s", [1, None])
+    assert cur.fetchone() == (1, None)
     assert cur.pgresult.fformat(0) == 1
     assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
 
 
 def test_execute_binary(conn):
     cur = conn.cursor()
-    unk = "foo" if is_crdb(conn) else None
-    cur.execute("select %s, %s", [1, unk], binary=True)
-    assert cur.fetchone() == (1, unk)
+    cur.execute("select %s, %s", [1, None], binary=True)
+    assert cur.fetchone() == (1, None)
     assert cur.pgresult.fformat(0) == 1
     assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
 
 
 def test_binary_cursor_text_override(conn):
     cur = conn.cursor(binary=True)
-    unk = "foo" if is_crdb(conn) else None
-    cur.execute("select %s, %s", [1, unk], binary=False)
-    assert cur.fetchone() == (1, unk)
+    cur.execute("select %s, %s", [1, None], binary=False)
+    assert cur.fetchone() == (1, None)
     assert cur.pgresult.fformat(0) == 0
     assert cur.pgresult.get_value(0, 0) == b"1"
 
index 8b2b618f028dfadaa5f0553e5f5a4cb040691eda..2161a30d73175b5c669e8672ce58d4f9e6231db7 100644 (file)
@@ -11,7 +11,7 @@ from psycopg.adapt import PyFormat
 from .utils import gc_collect
 from .test_cursor import my_row_factory
 from .test_cursor import execmany, _execmany  # noqa: F401
-from .fix_crdb import is_crdb, crdb_encoding
+from .fix_crdb import crdb_encoding
 
 execmany = execmany  # avoid F811 underneath
 pytestmark = pytest.mark.asyncio
@@ -212,28 +212,25 @@ async def test_fetchone(aconn):
 
 
 async def test_binary_cursor_execute(aconn):
-    unk = "foo" if is_crdb(aconn) else None
     cur = aconn.cursor(binary=True)
-    await cur.execute("select %s, %s", [1, unk])
-    assert (await cur.fetchone()) == (1, unk)
+    await cur.execute("select %s, %s", [1, None])
+    assert (await cur.fetchone()) == (1, None)
     assert cur.pgresult.fformat(0) == 1
     assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
 
 
 async def test_execute_binary(aconn):
-    unk = "foo" if is_crdb(aconn) else None
     cur = aconn.cursor()
-    await cur.execute("select %s, %s", [1, unk], binary=True)
-    assert (await cur.fetchone()) == (1, unk)
+    await cur.execute("select %s, %s", [1, None], binary=True)
+    assert (await cur.fetchone()) == (1, None)
     assert cur.pgresult.fformat(0) == 1
     assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
 
 
 async def test_binary_cursor_text_override(aconn):
-    unk = "foo" if is_crdb(aconn) else None
     cur = aconn.cursor(binary=True)
-    await cur.execute("select %s, %s", [1, unk], binary=False)
-    assert (await cur.fetchone()) == (1, unk)
+    await cur.execute("select %s, %s", [1, None], binary=False)
+    assert (await cur.fetchone()) == (1, None)
     assert cur.pgresult.fformat(0) == 0
     assert cur.pgresult.get_value(0, 0) == b"1"
 
index 716496eab42583e74292e40735613a1c01ae51fa..dea238ca472a78da2c1647f74798cf742945ad9d 100644 (file)
@@ -10,8 +10,6 @@ import pytest
 import psycopg
 from psycopg.rows import namedtuple_row
 
-from .fix_crdb import is_crdb
-
 
 @pytest.mark.parametrize("value", [None, 0, 3])
 def test_prepare_threshold_init(dsn, value):
@@ -181,9 +179,7 @@ def test_evict_lru_deallocate(conn):
 
 def test_different_types(conn):
     conn.prepare_threshold = 0
-    # CRDB can't roundtrip None
-    unk = "foo" if is_crdb(conn) else None
-    conn.execute("select %s", [unk])
+    conn.execute("select %s", [None])
     conn.execute("select %s", [dt.date(2000, 1, 1)])
     conn.execute("select %s", [42])
     conn.execute("select %s", [41])
index 2983f2188179c746dfc4890ded99e0cda48f857b..a40a169e4667f7209c54f3ce7cf32c53194991cf 100644 (file)
@@ -10,8 +10,6 @@ import pytest
 import psycopg
 from psycopg.rows import namedtuple_row
 
-from .fix_crdb import is_crdb
-
 pytestmark = pytest.mark.asyncio
 
 
@@ -175,9 +173,7 @@ async def test_evict_lru_deallocate(aconn):
 
 async def test_different_types(aconn):
     aconn.prepare_threshold = 0
-    # CRDB can't roundtrip None
-    unk = "foo" if is_crdb(aconn) else None
-    await aconn.execute("select %s", [unk])
+    await aconn.execute("select %s", [None])
     await aconn.execute("select %s", [dt.date(2000, 1, 1)])
     await aconn.execute("select %s", [42])
     await aconn.execute("select %s", [41])