From: Daniele Varrazzo Date: Thu, 21 Apr 2022 08:47:11 +0000 (+0200) Subject: fix(enum): raise DataError in case of labels mismatch on load X-Git-Tag: 3.1~137^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0ccac16774fbae5c93727d20ef3b7b17e155ad89;p=thirdparty%2Fpsycopg.git fix(enum): raise DataError in case of labels mismatch on load A DataError subclass is already raised on mismatch on dump. Be consistent. --- diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py index f79a7d7d8..a9965ea7c 100644 --- a/psycopg/psycopg/types/enum.py +++ b/psycopg/psycopg/types/enum.py @@ -4,13 +4,13 @@ Adapters for the enum type. from enum import Enum from typing import Type, Any, Dict, Generic, Optional, TypeVar, cast -from ..adapt import Dumper, Loader from .. import postgres -from .._encodings import pgconn_encoding -from .._typeinfo import EnumInfo as EnumInfo # exported here -from ..abc import AdaptContext -from ..adapt import Buffer +from .. import errors as e from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader +from .._typeinfo import EnumInfo as EnumInfo # exported here +from .._encodings import pgconn_encoding E = TypeVar("E", bound=Enum) @@ -32,7 +32,12 @@ class EnumLoader(Loader, Generic[E]): else: label = data.decode(self._encoding) - return self.enum[label] + try: + return self.enum[label] + except KeyError: + raise e.DataError( + f"bad memeber for enum {self.enum.__qualname__}: {label!r}" + ) class EnumBinaryLoader(EnumLoader[E]): diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py index 3cffe3729..156db54a6 100644 --- a/tests/types/test_enum.py +++ b/tests/types/test_enum.py @@ -2,7 +2,7 @@ from enum import Enum, auto import pytest -from psycopg import pq, sql +from psycopg import pq, sql, errors as e from psycopg.adapt import PyFormat from psycopg.types import TypeInfo from psycopg.types.enum import EnumInfo, register_enum @@ -241,3 +241,17 @@ async def test_enum_async(aconn, testenum, encoding, fmt_in, fmt_out): cur = await cur.execute(f"select %{fmt_in}", [list(enum)]) assert (await cur.fetchone())[0] == list(enum) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_error(conn, fmt_in, fmt_out): + conn.autocommit = True + + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, StrTestEnum) + + with pytest.raises(e.DataError): + conn.execute("select %s::text", [StrTestEnum.ONE]).fetchone() + with pytest.raises(e.DataError): + conn.execute("select 'BAR'::puretestenum").fetchone()