]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added cast from varchar, bpchar, name
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Apr 2020 06:15:36 +0000 (18:15 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Apr 2020 06:15:36 +0000 (18:15 +1200)
Note that bpchar (i.e. the output format for "char" - with quotes) and
name are always decoded, even if the db is SQL_ASCII, because these
types are mostly used in system catalogs and there should be no such
crap as unencded data there.

psycopg3/types/text.py
tests/types/test_text.py

index ae3673c877eb60c711f169b6bdad2a489c464ed5..d67bd1b5f9f79440be5023322c2498f57905dffe 100644 (file)
@@ -37,6 +37,8 @@ class StringAdapter(Adapter):
 
 @TypeCaster.text(builtins["text"].oid)
 @TypeCaster.binary(builtins["text"].oid)
+@TypeCaster.text(builtins["varchar"].oid)
+@TypeCaster.binary(builtins["varchar"].oid)
 class StringCaster(TypeCaster):
 
     decode: Optional[DecodeFunc]
@@ -60,6 +62,24 @@ class StringCaster(TypeCaster):
             return data
 
 
+@TypeCaster.text(builtins["name"].oid)
+@TypeCaster.binary(builtins["name"].oid)
+@TypeCaster.text(builtins["bpchar"].oid)
+@TypeCaster.binary(builtins["bpchar"].oid)
+class NameCaster(TypeCaster):
+    def __init__(self, oid: int, context: AdaptContext):
+        super().__init__(oid, context)
+
+        self.decode: DecodeFunc
+        if self.connection is not None:
+            self.decode = self.connection.codec.decode
+        else:
+            self.decode = codecs.lookup("utf8").decode
+
+    def cast(self, data: bytes) -> str:
+        return self.decode(data)[0]
+
+
 @Adapter.text(bytes)
 class BytesAdapter(Adapter):
     def __init__(self, src: type, context: AdaptContext = None):
index 37ee843ee0c059a870901db7f9a9384831c97570..cf49a907fcfd29ba894c83bfae254051d4cc1153 100644 (file)
@@ -20,12 +20,14 @@ def test_adapt_1char(conn, fmt_in):
         assert cur.fetchone()[0], chr(i)
 
 
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
-def test_cast_1char(conn, fmt_out):
+def test_cast_1char(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
     for i in range(1, 256):
-        cur.execute("select chr(%s::int)", (i,))
-        assert cur.fetchone()[0] == chr(i)
+        cur.execute(f"select chr(%s::int)::{typename}", (i,))
+        res = cur.fetchone()[0]
+        assert res == chr(i)
 
     assert cur.pgresult.fformat(0) == fmt_out
 
@@ -63,40 +65,58 @@ def test_adapt_badenc(conn, fmt_in):
 
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("encoding", ["utf8", "latin9"])
-def test_cast_enc(conn, fmt_out, encoding):
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
+def test_cast_enc(conn, typename, encoding, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
     conn.encoding = encoding
-    (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone()
+    (res,) = cur.execute(
+        f"select chr(%s::int)::{typename}", (ord(eur),)
+    ).fetchone()
     assert res == eur
 
 
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
-def test_cast_badenc(conn, fmt_out):
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
+def test_cast_badenc(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
     conn.encoding = "latin1"
     with pytest.raises(psycopg3.DatabaseError):
-        cur.execute("select chr(%s::int)", (ord(eur),))
+        cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
 
 
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
-def test_cast_ascii(conn, fmt_out):
+@pytest.mark.parametrize("typename", ["text", "varchar"])
+def test_cast_ascii(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
     conn.encoding = "sql_ascii"
-    (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone()
+    (res,) = cur.execute(
+        f"select chr(%s::int)::{typename}", (ord(eur),)
+    ).fetchone()
     assert res == eur.encode("utf8")
 
 
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("typename", ["name", "bpchar"])
+def test_cast_ascii_encanyway(conn, typename, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
+
+    conn.encoding = "sql_ascii"
+    (res,) = cur.execute(f"select 'aa'::{typename}").fetchone()
+    assert res == "aa"
+
+
 @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
-def test_text_array(conn, fmt_in, fmt_out):
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
+def test_text_array(conn, typename, fmt_in, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
     ph = "%s" if fmt_in == Format.TEXT else "%b"
     a = list(map(chr, range(1, 256))) + [eur]
 
-    (res,) = cur.execute(f"select {ph}::text[]", (a,)).fetchone()
+    (res,) = cur.execute(f"select {ph}::{typename}[]", (a,)).fetchone()
     assert res == a