From: Daniele Varrazzo Date: Thu, 2 Apr 2020 14:57:09 +0000 (+1300) Subject: Handle sql_ascii encoding as binary X-Git-Tag: 3.0.dev0~624 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=f8a8a3c77888dbfe92e1fba76ea7c9916496f8f3;p=thirdparty%2Fpsycopg.git Handle sql_ascii encoding as binary --- diff --git a/psycopg3/connection.py b/psycopg3/connection.py index bfe1237ca..23b4f1a50 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -67,6 +67,10 @@ class BaseConnection: def decode(self, b: bytes) -> str: return self.codec.decode(b)[0] + @property + def encoding(self) -> str: + return self.pgconn.parameter_status(b"client_encoding").decode("ascii") + @classmethod def _connect_gen(cls, conninfo: str) -> ConnectGen: """ diff --git a/psycopg3/types/text.py b/psycopg3/types/text.py index 421142ee6..d8327376c 100644 --- a/psycopg3/types/text.py +++ b/psycopg3/types/text.py @@ -25,7 +25,10 @@ class StringAdapter(Adapter): self._encode: EncodeFunc if conn is not None: - self._encode = conn.codec.encode + if conn.encoding != "SQL_ASCII": + self._encode = conn.codec.encode + else: + self._encode = codecs.lookup("utf8").encode else: self._encode = codecs.lookup("utf8").encode @@ -43,7 +46,7 @@ class StringCaster(Typecaster): super().__init__(oid, conn) if conn is not None: - if conn.pgenc != b"SQL_ASCII": + if conn.encoding != "SQL_ASCII": self.decode = conn.codec.decode else: self.decode = None diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 03c87aca8..5aaed71b5 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,3 +1,6 @@ +import pytest + + def test_execute_many(conn): cur = conn.cursor() rv = cur.execute("select 'foo'; select 'bar'") @@ -43,3 +46,18 @@ def test_execute_binary_result(conn): assert row[1] is None row = cur.fetchone() assert row is None + + +@pytest.mark.parametrize("encoding", ["utf8", "latin9"]) +def test_query_encode(conn, encoding): + conn.encoding = encoding + cur = conn.cursor() + (res,) = cur.execute("select '\u20ac'").fetchone() + assert res == "\u20ac" + + +def test_query_badenc(conn): + conn.encoding = "latin1" + cur = conn.cursor() + with pytest.raises(UnicodeEncodeError): + cur.execute("select '\u20ac'") diff --git a/tests/types/test_text.py b/tests/types/test_text.py index ada260d8f..ace5aacd4 100644 --- a/tests/types/test_text.py +++ b/tests/types/test_text.py @@ -1,5 +1,6 @@ import pytest +import psycopg3 from psycopg3.adapt import Format @@ -11,11 +12,9 @@ from psycopg3.adapt import Format @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) def test_adapt_1char(conn, format): cur = conn.cursor() - query = "select %s = chr(%%s::int)" % ( - "%s" if format == Format.TEXT else "%b" - ) + ph = "%s" if format == Format.TEXT else "%b" for i in range(1, 256): - cur.execute(query, (chr(i), i)) + cur.execute("select %s = chr(%%s::int)" % ph, (chr(i), i)) assert cur.fetchone()[0], chr(i) @@ -29,6 +28,71 @@ def test_cast_1char(conn, format): assert cur.pgresult.fformat(0) == format +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("encoding", ["utf8", "latin9"]) +def test_adapt_enc(conn, format, encoding): + eur = "\u20ac" + cur = conn.cursor() + ph = "%s" if format == Format.TEXT else "%b" + + conn.encoding = encoding + (res,) = cur.execute("select %s::bytea" % ph, (eur,)).fetchone() + assert res == eur.encode("utf8") + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_adapt_ascii(conn, format): + eur = "\u20ac" + cur = conn.cursor(binary=format == Format.BINARY) + ph = "%s" if format == Format.TEXT else "%b" + + conn.encoding = "sql_ascii" + (res,) = cur.execute("select ascii(%s)" % ph, (eur,)).fetchone() + assert res == ord(eur) + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_adapt_badenc(conn, format): + eur = "\u20ac" + cur = conn.cursor() + ph = "%s" if format == Format.TEXT else "%b" + + conn.encoding = "latin1" + with pytest.raises(UnicodeEncodeError): + cur.execute("select %s::bytea" % ph, (eur,)) + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("encoding", ["utf8", "latin9"]) +def test_cast_enc(conn, format, encoding): + eur = "\u20ac" + cur = conn.cursor(binary=format == Format.BINARY) + + conn.encoding = encoding + (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone() + assert res == eur + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_cast_badenc(conn, format): + eur = "\u20ac" + cur = conn.cursor(binary=format == Format.BINARY) + + conn.encoding = "latin1" + with pytest.raises(psycopg3.DatabaseError): + cur.execute("select chr(%s::int)", (ord(eur),)) + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_cast_ascii(conn, format): + eur = "\u20ac" + cur = conn.cursor(binary=format == Format.BINARY) + + conn.encoding = "sql_ascii" + (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone() + assert res == eur.encode("utf8") + + # # tests with bytea # @@ -37,11 +101,10 @@ def test_cast_1char(conn, format): @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) def test_adapt_1byte(conn, format): cur = conn.cursor() - query = "select %s = %%s::bytea" % ( - "%s" if format == Format.TEXT else "%b" - ) + ph = "%s" if format == Format.TEXT else "%b" + "select %s = %%s::bytea" % ph for i in range(0, 256): - cur.execute(query, (bytes([i]), fr"\x{i:02x}")) + cur.execute("select %s = %%s::bytea" % ph, (bytes([i]), fr"\x{i:02x}")) assert cur.fetchone()[0], i