]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix dump of enum subtypes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 17 Mar 2021 17:10:04 +0000 (18:10 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 17 Mar 2021 17:10:04 +0000 (18:10 +0100)
psycopg3_c/psycopg3_c/_psycopg3/adapt.pyx
psycopg3_c/psycopg3_c/types/text.pyx
tests/types/test_numeric.py
tests/types/test_text.py

index e75fb426a5d72f8e2e4ce5e34fc96a3bd9eca2da..35545e23ff848854a19dead353bb18a08b11d64e 100644 (file)
@@ -35,7 +35,7 @@ cdef class CDumper:
     cdef public libpq.Oid oid
     cdef pq.PGconn _pgconn
 
-    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+    def __init__(self, cls, context: Optional[AdaptContext] = None):
         self.cls = cls
         conn = context.connection if context is not None else None
         self._pgconn = conn.pgconn if conn is not None else None
index c80183aa6742d6f8281cda5596770cd14e804aec..808ce3142213d9562be6bf68a98753bec4464bba 100644 (file)
@@ -30,7 +30,7 @@ cdef class _StringDumper(CDumper):
     cdef char *encoding
     cdef bytes _bytes_encoding  # needed to keep `encoding` alive
 
-    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+    def __init__(self, cls, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
 
         self.is_utf8 = 0
index 437601504d2d72b606e978286faee084b3924967..886a3143943671b8201a52dd69a89b6cd854acd2 100644 (file)
@@ -67,6 +67,18 @@ def test_dump_int_subtypes(conn, val, expr, fmt_in):
     assert cur.fetchone()[0] is True
 
 
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_enum(conn, fmt_in):
+    import enum
+
+    class MyEnum(enum.IntEnum):
+        foo = 42
+
+    cur = conn.cursor()
+    (res,) = cur.execute("select %s", (MyEnum.foo,)).fetchone()
+    assert res == 42
+
+
 @pytest.mark.parametrize(
     "val, expr",
     [
index c18dcd35818792c8f3e732c823605c73bc9dce71..d8c27b456350a45870ea13af3f7bdba49c80246d 100644 (file)
@@ -102,6 +102,18 @@ def test_dump_utf8_badenc(conn, fmt_in):
         cur.execute(f"select %{fmt_in}", ("\uddf8",))
 
 
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_enum(conn, fmt_in):
+    from enum import Enum
+
+    class MyEnum(str, Enum):
+        foo = "foo"
+
+    cur = conn.cursor()
+    (res,) = cur.execute("select %s", (MyEnum.foo,)).fetchone()
+    assert res == "foo"
+
+
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("encoding", ["utf8", "latin9"])
 @pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])