From 18bdc43c00eda38af07592c71cd22a5f1421c9bb Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 8 Apr 2020 18:15:36 +1200 Subject: [PATCH] Added cast from varchar, bpchar, name Note that bpchar (i.e. the output format for "char" - with quotes) and name are always decoded, even if the db is SQL_ASCII, because these types are mostly used in system catalogs and there should be no such crap as unencded data there. --- psycopg3/types/text.py | 20 +++++++++++++++++++ tests/types/test_text.py | 42 +++++++++++++++++++++++++++++----------- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/psycopg3/types/text.py b/psycopg3/types/text.py index ae3673c87..d67bd1b5f 100644 --- a/psycopg3/types/text.py +++ b/psycopg3/types/text.py @@ -37,6 +37,8 @@ class StringAdapter(Adapter): @TypeCaster.text(builtins["text"].oid) @TypeCaster.binary(builtins["text"].oid) +@TypeCaster.text(builtins["varchar"].oid) +@TypeCaster.binary(builtins["varchar"].oid) class StringCaster(TypeCaster): decode: Optional[DecodeFunc] @@ -60,6 +62,24 @@ class StringCaster(TypeCaster): return data +@TypeCaster.text(builtins["name"].oid) +@TypeCaster.binary(builtins["name"].oid) +@TypeCaster.text(builtins["bpchar"].oid) +@TypeCaster.binary(builtins["bpchar"].oid) +class NameCaster(TypeCaster): + def __init__(self, oid: int, context: AdaptContext): + super().__init__(oid, context) + + self.decode: DecodeFunc + if self.connection is not None: + self.decode = self.connection.codec.decode + else: + self.decode = codecs.lookup("utf8").decode + + def cast(self, data: bytes) -> str: + return self.decode(data)[0] + + @Adapter.text(bytes) class BytesAdapter(Adapter): def __init__(self, src: type, context: AdaptContext = None): diff --git a/tests/types/test_text.py b/tests/types/test_text.py index 37ee843ee..cf49a907f 100644 --- a/tests/types/test_text.py +++ b/tests/types/test_text.py @@ -20,12 +20,14 @@ def test_adapt_1char(conn, fmt_in): assert cur.fetchone()[0], chr(i) +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) -def test_cast_1char(conn, fmt_out): +def test_cast_1char(conn, typename, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) for i in range(1, 256): - cur.execute("select chr(%s::int)", (i,)) - assert cur.fetchone()[0] == chr(i) + cur.execute(f"select chr(%s::int)::{typename}", (i,)) + res = cur.fetchone()[0] + assert res == chr(i) assert cur.pgresult.fformat(0) == fmt_out @@ -63,40 +65,58 @@ def test_adapt_badenc(conn, fmt_in): @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("encoding", ["utf8", "latin9"]) -def test_cast_enc(conn, fmt_out, encoding): +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) +def test_cast_enc(conn, typename, encoding, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) conn.encoding = encoding - (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone() + (res,) = cur.execute( + f"select chr(%s::int)::{typename}", (ord(eur),) + ).fetchone() assert res == eur @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) -def test_cast_badenc(conn, fmt_out): +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) +def test_cast_badenc(conn, typename, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) conn.encoding = "latin1" with pytest.raises(psycopg3.DatabaseError): - cur.execute("select chr(%s::int)", (ord(eur),)) + cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),)) @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) -def test_cast_ascii(conn, fmt_out): +@pytest.mark.parametrize("typename", ["text", "varchar"]) +def test_cast_ascii(conn, typename, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) conn.encoding = "sql_ascii" - (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone() + (res,) = cur.execute( + f"select chr(%s::int)::{typename}", (ord(eur),) + ).fetchone() assert res == eur.encode("utf8") +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("typename", ["name", "bpchar"]) +def test_cast_ascii_encanyway(conn, typename, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) + + conn.encoding = "sql_ascii" + (res,) = cur.execute(f"select 'aa'::{typename}").fetchone() + assert res == "aa" + + @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) -def test_text_array(conn, fmt_in, fmt_out): +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) +def test_text_array(conn, typename, fmt_in, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) ph = "%s" if fmt_in == Format.TEXT else "%b" a = list(map(chr, range(1, 256))) + [eur] - (res,) = cur.execute(f"select {ph}::text[]", (a,)).fetchone() + (res,) = cur.execute(f"select {ph}::{typename}[]", (a,)).fetchone() assert res == a -- 2.47.3