From 92b30b0268a9c43dccbc3c0cd60a4285a923dac5 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 3 Aug 2023 14:02:23 +0100 Subject: [PATCH] feat(crdb): add numpy support --- psycopg/psycopg/crdb/_types.py | 17 ++++++++++++----- tests/types/test_numpy.py | 1 + 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/psycopg/psycopg/crdb/_types.py b/psycopg/psycopg/crdb/_types.py index 1cbd7d040..f19e7c6ce 100644 --- a/psycopg/psycopg/crdb/_types.py +++ b/psycopg/psycopg/crdb/_types.py @@ -37,12 +37,12 @@ def register_crdb_adapters(context: AdaptContext) -> None: _register_postgres_adapters(context) - # String must come after enum to map text oid -> string dumper + # String must come after enum and none to map text oid -> string dumper + _register_crdb_none_adapters(context) _register_crdb_enum_adapters(context) _register_crdb_string_adapters(context) _register_crdb_json_adapters(context) _register_crdb_net_adapters(context) - _register_crdb_none_adapters(context) dbapi20.register_dbapi20_adapters(adapters) @@ -53,16 +53,23 @@ def _register_postgres_adapters(context: AdaptContext) -> None: # Same adapters used by PostgreSQL, or a good starting point for customization from ..types import array, bool, composite, datetime - from ..types import numeric, string, uuid + from ..types import numeric, numpy, string, uuid array.register_default_adapters(context) - bool.register_default_adapters(context) composite.register_default_adapters(context) datetime.register_default_adapters(context) - numeric.register_default_adapters(context) string.register_default_adapters(context) uuid.register_default_adapters(context) + # Both numpy Decimal and uint64 dumpers use the numeric oid, but the former + # covers the entire numeric domain, whereas the latter only deals with + # integers. For this reason, if we specify dumpers by oid, we want to make + # sure to get the Decimal dumper. We enforce that by registering the + # numeric dumpers last. + numpy.register_default_adapters(context) + bool.register_default_adapters(context) + numeric.register_default_adapters(context) + def _register_crdb_string_adapters(context: AdaptContext) -> None: from ..types import string diff --git a/tests/types/test_numpy.py b/tests/types/test_numpy.py index 84e108243..ba5856cda 100644 --- a/tests/types/test_numpy.py +++ b/tests/types/test_numpy.py @@ -170,6 +170,7 @@ def test_dump_float(conn, nptype, val, pgtype, fmt_in): ], ) @pytest.mark.parametrize("fmt", Format) +@pytest.mark.crdb_skip("copy") def test_copy_by_oid(conn, val, nptype, pgtypes, fmt): nptype = getattr(np, nptype) val = nptype(val) -- 2.47.2