From ea17aaad16f345720843860e51e8c08d03bf53b1 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 17 Mar 2021 18:39:44 +0100 Subject: [PATCH] Dump all string subclasses as text by default Useful for enums. --- psycopg3/psycopg3/adapt.py | 2 +- tests/test_adapt.py | 6 +++--- tests/types/test_text.py | 8 ++++++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/psycopg3/psycopg3/adapt.py b/psycopg3/psycopg3/adapt.py index 9d1f35101..afa82b1fa 100644 --- a/psycopg3/psycopg3/adapt.py +++ b/psycopg3/psycopg3/adapt.py @@ -222,7 +222,7 @@ class AdaptersMap(AdaptContext): if format == Format.AUTO: # When dumping a string with %s we may refer to any type actually, # but the user surely passed a text format - if cls is str: + if issubclass(cls, str): dmaps = [self._dumpers[pq.Format.TEXT]] else: dmaps = [ diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 27607d926..1465803bc 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -48,7 +48,7 @@ def test_dump_connection_ctx(conn): cur = conn.cursor() cur.execute("select %s", [MyStr("hello")]) - assert cur.fetchone() == ("hellob",) + assert cur.fetchone() == ("hellot",) cur.execute("select %t", [MyStr("hello")]) assert cur.fetchone() == ("hellot",) cur.execute("select %b", [MyStr("hello")]) @@ -64,7 +64,7 @@ def test_dump_cursor_ctx(conn): make_bin_dumper("bc").register(str, cur) cur.execute("select %s", [MyStr("hello")]) - assert cur.fetchone() == ("hellobc",) + assert cur.fetchone() == ("hellotc",) cur.execute("select %t", [MyStr("hello")]) assert cur.fetchone() == ("hellotc",) cur.execute("select %b", [MyStr("hello")]) @@ -72,7 +72,7 @@ def test_dump_cursor_ctx(conn): cur = conn.cursor() cur.execute("select %s", [MyStr("hello")]) - assert cur.fetchone() == ("hellob",) + assert cur.fetchone() == ("hellot",) cur.execute("select %t", [MyStr("hello")]) assert cur.fetchone() == ("hellot",) cur.execute("select %b", [MyStr("hello")]) diff --git a/tests/types/test_text.py b/tests/types/test_text.py index d8c27b456..998afb1b8 100644 --- a/tests/types/test_text.py +++ b/tests/types/test_text.py @@ -102,15 +102,19 @@ def test_dump_utf8_badenc(conn, fmt_in): cur.execute(f"select %{fmt_in}", ("\uddf8",)) -@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT]) def test_dump_enum(conn, fmt_in): from enum import Enum class MyEnum(str, Enum): foo = "foo" + bar = "bar" cur = conn.cursor() - (res,) = cur.execute("select %s", (MyEnum.foo,)).fetchone() + cur.execute("create type myenum as enum ('foo', 'bar')") + cur.execute("create table with_enum (e myenum)") + cur.execute(f"insert into with_enum (e) values (%{fmt_in})", (MyEnum.foo,)) + (res,) = cur.execute("select e from with_enum").fetchone() assert res == "foo" -- 2.47.2