]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(enum): raise DataError in case of labels mismatch on load
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Apr 2022 08:47:11 +0000 (10:47 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Apr 2022 03:03:24 +0000 (05:03 +0200)
A DataError subclass is already raised on mismatch on dump. Be
consistent.

psycopg/psycopg/types/enum.py
tests/types/test_enum.py

index f79a7d7d89864b9f0856507120c04fbfbff1056f..a9965ea7c8d8af3cd713cefbcc640cba4d33522a 100644 (file)
@@ -4,13 +4,13 @@ Adapters for the enum type.
 from enum import Enum
 from typing import Type, Any, Dict, Generic, Optional, TypeVar, cast
 
-from ..adapt import Dumper, Loader
 from .. import postgres
-from .._encodings import pgconn_encoding
-from .._typeinfo import EnumInfo as EnumInfo  # exported here
-from ..abc import AdaptContext
-from ..adapt import Buffer
+from .. import errors as e
 from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+from .._typeinfo import EnumInfo as EnumInfo  # exported here
+from .._encodings import pgconn_encoding
 
 
 E = TypeVar("E", bound=Enum)
@@ -32,7 +32,12 @@ class EnumLoader(Loader, Generic[E]):
         else:
             label = data.decode(self._encoding)
 
-        return self.enum[label]
+        try:
+            return self.enum[label]
+        except KeyError:
+            raise e.DataError(
+                f"bad memeber for enum {self.enum.__qualname__}: {label!r}"
+            )
 
 
 class EnumBinaryLoader(EnumLoader[E]):
index 3cffe372966f8ffa0c139fc69bfb7c1d4faf679a..156db54a6623e07fc0b18f88191596890bb55777 100644 (file)
@@ -2,7 +2,7 @@ from enum import Enum, auto
 
 import pytest
 
-from psycopg import pq, sql
+from psycopg import pq, sql, errors as e
 from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
 from psycopg.types.enum import EnumInfo, register_enum
@@ -241,3 +241,17 @@ async def test_enum_async(aconn, testenum, encoding, fmt_in, fmt_out):
 
     cur = await cur.execute(f"select %{fmt_in}", [list(enum)])
     assert (await cur.fetchone())[0] == list(enum)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_error(conn, fmt_in, fmt_out):
+    conn.autocommit = True
+
+    info = EnumInfo.fetch(conn, "puretestenum")
+    register_enum(info, conn, StrTestEnum)
+
+    with pytest.raises(e.DataError):
+        conn.execute("select %s::text", [StrTestEnum.ONE]).fetchone()
+    with pytest.raises(e.DataError):
+        conn.execute("select 'BAR'::puretestenum").fetchone()