]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(crdb): add numpy support 332/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 3 Aug 2023 13:02:23 +0000 (14:02 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Aug 2023 14:21:30 +0000 (15:21 +0100)
psycopg/psycopg/crdb/_types.py
tests/types/test_numpy.py

index 1cbd7d04062d8eac6e2a5e92588c4278fc3b00ab..f19e7c6ce427a858b3a7253c5f076b700a5ddd53 100644 (file)
@@ -37,12 +37,12 @@ def register_crdb_adapters(context: AdaptContext) -> None:
 
     _register_postgres_adapters(context)
 
-    # String must come after enum to map text oid -> string dumper
+    # String must come after enum and none to map text oid -> string dumper
+    _register_crdb_none_adapters(context)
     _register_crdb_enum_adapters(context)
     _register_crdb_string_adapters(context)
     _register_crdb_json_adapters(context)
     _register_crdb_net_adapters(context)
-    _register_crdb_none_adapters(context)
 
     dbapi20.register_dbapi20_adapters(adapters)
 
@@ -53,16 +53,23 @@ def _register_postgres_adapters(context: AdaptContext) -> None:
     # Same adapters used by PostgreSQL, or a good starting point for customization
 
     from ..types import array, bool, composite, datetime
-    from ..types import numeric, string, uuid
+    from ..types import numeric, numpy, string, uuid
 
     array.register_default_adapters(context)
-    bool.register_default_adapters(context)
     composite.register_default_adapters(context)
     datetime.register_default_adapters(context)
-    numeric.register_default_adapters(context)
     string.register_default_adapters(context)
     uuid.register_default_adapters(context)
 
+    # Both numpy Decimal and uint64 dumpers use the numeric oid, but the former
+    # covers the entire numeric domain, whereas the latter only deals with
+    # integers. For this reason, if we specify dumpers by oid, we want to make
+    # sure to get the Decimal dumper. We enforce that by registering the
+    # numeric dumpers last.
+    numpy.register_default_adapters(context)
+    bool.register_default_adapters(context)
+    numeric.register_default_adapters(context)
+
 
 def _register_crdb_string_adapters(context: AdaptContext) -> None:
     from ..types import string
index 84e1082433e863c994c9d7c0ba969025633a8e62..ba5856cda839f9b4688f9e33d23b556b37b50a6c 100644 (file)
@@ -170,6 +170,7 @@ def test_dump_float(conn, nptype, val, pgtype, fmt_in):
     ],
 )
 @pytest.mark.parametrize("fmt", Format)
+@pytest.mark.crdb_skip("copy")
 def test_copy_by_oid(conn, val, nptype, pgtypes, fmt):
     nptype = getattr(np, nptype)
     val = nptype(val)