From: Daniele Varrazzo Date: Thu, 21 Apr 2022 01:29:05 +0000 (+0200) Subject: fix: fix dumping int enums in text mode, python implementation X-Git-Tag: 3.1~139 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fb458d6b0482f8478aa6a94787ae2d0922fb94a3;p=thirdparty%2Fpsycopg.git fix: fix dumping int enums in text mode, python implementation The error was hidden by a broken test, failing to test the text mode. --- diff --git a/docs/news.rst b/docs/news.rst index 66e43178b..d7a2cf8cc 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -26,6 +26,12 @@ Psycopg 3.1 (unreleased) - Drop support for Python 3.6. +Psycopg 3.0.12 (unreleased) +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Fix dumping `~enum.IntEnum` in text mode, Python implementation. + + Current release --------------- diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py index cc5ed1a7e..dc4d52d54 100644 --- a/psycopg/psycopg/types/numeric.py +++ b/psycopg/psycopg/types/numeric.py @@ -31,19 +31,23 @@ from .._wrappers import ( ) -class _NumberDumper(Dumper): +class _IntDumper(Dumper): def dump(self, obj: Any) -> bytes: - return str(obj).encode() + # Convert to int in order to dump IntEnum correctly + return str(int(obj)).encode() def quote(self, obj: Any) -> bytes: value = self.dump(obj) return value if obj >= 0 else b" " + value -class _SpecialValuesDumper(_NumberDumper): +class _SpecialValuesDumper(Dumper): _special: Dict[bytes, bytes] = {} + def dump(self, obj: Any) -> bytes: + return str(obj).encode() + def quote(self, obj: Any) -> bytes: value = self.dump(obj) @@ -103,23 +107,23 @@ class DecimalDumper(_SpecialValuesDumper): } -class Int2Dumper(_NumberDumper): +class Int2Dumper(_IntDumper): oid = postgres.types["int2"].oid -class Int4Dumper(_NumberDumper): +class Int4Dumper(_IntDumper): oid = postgres.types["int4"].oid -class Int8Dumper(_NumberDumper): +class Int8Dumper(_IntDumper): oid = postgres.types["int8"].oid -class IntNumericDumper(_NumberDumper): +class IntNumericDumper(_IntDumper): oid = postgres.types["numeric"].oid -class OidDumper(_NumberDumper): +class OidDumper(_IntDumper): oid = postgres.types["oid"].oid diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py index d8fd8e069..a6d34e8b1 100644 --- a/tests/types/test_numeric.py +++ b/tests/types/test_numeric.py @@ -1,3 +1,4 @@ +import enum from decimal import Decimal from math import isnan, isinf, exp @@ -71,16 +72,21 @@ def test_dump_int_subtypes(conn, val, expr, fmt_in): assert ok -@pytest.mark.parametrize("fmt_in", PyFormat) -def test_dump_enum(conn, fmt_in): - import enum +class MyEnum(enum.IntEnum): + foo = 42 + - class MyEnum(enum.IntEnum): - foo = 42 +class MyMixinEnum(enum.IntEnum): + foo = 42000000 + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("enum", [MyEnum, MyMixinEnum]) +def test_dump_enum(conn, fmt_in, enum): cur = conn.cursor() - (res,) = cur.execute("select %s", (MyEnum.foo,)).fetchone() - assert res == 42 + cur.execute(f"select %{fmt_in}", (enum.foo,)) + (res,) = cur.fetchone() + assert res == enum.foo.value @pytest.mark.parametrize(