]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Dump all string subclasses as text by default
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 17 Mar 2021 17:39:44 +0000 (18:39 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 17 Mar 2021 17:39:44 +0000 (18:39 +0100)
Useful for enums.

psycopg3/psycopg3/adapt.py
tests/test_adapt.py
tests/types/test_text.py

index 9d1f351012d88d75da4c0dfea323c2f9356a4599..afa82b1fa7ab51c2731300c24ddd7ead0568b7f6 100644 (file)
@@ -222,7 +222,7 @@ class AdaptersMap(AdaptContext):
         if format == Format.AUTO:
             # When dumping a string with %s we may refer to any type actually,
             # but the user surely passed a text format
-            if cls is str:
+            if issubclass(cls, str):
                 dmaps = [self._dumpers[pq.Format.TEXT]]
             else:
                 dmaps = [
index 27607d92618e888024b32bf13cd385d1c967270f..1465803bcce17dfb2798d09d74fbffaf786260c6 100644 (file)
@@ -48,7 +48,7 @@ def test_dump_connection_ctx(conn):
 
     cur = conn.cursor()
     cur.execute("select %s", [MyStr("hello")])
-    assert cur.fetchone() == ("hellob",)
+    assert cur.fetchone() == ("hellot",)
     cur.execute("select %t", [MyStr("hello")])
     assert cur.fetchone() == ("hellot",)
     cur.execute("select %b", [MyStr("hello")])
@@ -64,7 +64,7 @@ def test_dump_cursor_ctx(conn):
     make_bin_dumper("bc").register(str, cur)
 
     cur.execute("select %s", [MyStr("hello")])
-    assert cur.fetchone() == ("hellobc",)
+    assert cur.fetchone() == ("hellotc",)
     cur.execute("select %t", [MyStr("hello")])
     assert cur.fetchone() == ("hellotc",)
     cur.execute("select %b", [MyStr("hello")])
@@ -72,7 +72,7 @@ def test_dump_cursor_ctx(conn):
 
     cur = conn.cursor()
     cur.execute("select %s", [MyStr("hello")])
-    assert cur.fetchone() == ("hellob",)
+    assert cur.fetchone() == ("hellot",)
     cur.execute("select %t", [MyStr("hello")])
     assert cur.fetchone() == ("hellot",)
     cur.execute("select %b", [MyStr("hello")])
index d8c27b456350a45870ea13af3f7bdba49c80246d..998afb1b8ddd3a5af29a41ca10753dd0f9b693bf 100644 (file)
@@ -102,15 +102,19 @@ def test_dump_utf8_badenc(conn, fmt_in):
         cur.execute(f"select %{fmt_in}", ("\uddf8",))
 
 
-@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT])
 def test_dump_enum(conn, fmt_in):
     from enum import Enum
 
     class MyEnum(str, Enum):
         foo = "foo"
+        bar = "bar"
 
     cur = conn.cursor()
-    (res,) = cur.execute("select %s", (MyEnum.foo,)).fetchone()
+    cur.execute("create type myenum as enum ('foo', 'bar')")
+    cur.execute("create table with_enum (e myenum)")
+    cur.execute(f"insert into with_enum (e) values (%{fmt_in})", (MyEnum.foo,))
+    (res,) = cur.execute("select e from with_enum").fetchone()
     assert res == "foo"