]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Handle sql_ascii encoding as binary
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 2 Apr 2020 14:57:09 +0000 (03:57 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 2 Apr 2020 14:57:33 +0000 (03:57 +1300)
psycopg3/connection.py
psycopg3/types/text.py
tests/test_cursor.py
tests/types/test_text.py

index bfe1237ca7a01f9e5cedea3604df4e7348c2bc40..23b4f1a50810fe389b91e118f01d4c7f95adb951 100644 (file)
@@ -67,6 +67,10 @@ class BaseConnection:
     def decode(self, b: bytes) -> str:
         return self.codec.decode(b)[0]
 
+    @property
+    def encoding(self) -> str:
+        return self.pgconn.parameter_status(b"client_encoding").decode("ascii")
+
     @classmethod
     def _connect_gen(cls, conninfo: str) -> ConnectGen:
         """
index 421142ee6b8f51b7e03cd5126800a9442f36868f..d8327376c88e1b4cafb5e584993df8e390c37e6f 100644 (file)
@@ -25,7 +25,10 @@ class StringAdapter(Adapter):
 
         self._encode: EncodeFunc
         if conn is not None:
-            self._encode = conn.codec.encode
+            if conn.encoding != "SQL_ASCII":
+                self._encode = conn.codec.encode
+            else:
+                self._encode = codecs.lookup("utf8").encode
         else:
             self._encode = codecs.lookup("utf8").encode
 
@@ -43,7 +46,7 @@ class StringCaster(Typecaster):
         super().__init__(oid, conn)
 
         if conn is not None:
-            if conn.pgenc != b"SQL_ASCII":
+            if conn.encoding != "SQL_ASCII":
                 self.decode = conn.codec.decode
             else:
                 self.decode = None
index 03c87aca873bd99d6f2bae41b273afbf6e85291a..5aaed71b57284db06c3cdf53ecd7b05cb8889967 100644 (file)
@@ -1,3 +1,6 @@
+import pytest
+
+
 def test_execute_many(conn):
     cur = conn.cursor()
     rv = cur.execute("select 'foo'; select 'bar'")
@@ -43,3 +46,18 @@ def test_execute_binary_result(conn):
     assert row[1] is None
     row = cur.fetchone()
     assert row is None
+
+
+@pytest.mark.parametrize("encoding", ["utf8", "latin9"])
+def test_query_encode(conn, encoding):
+    conn.encoding = encoding
+    cur = conn.cursor()
+    (res,) = cur.execute("select '\u20ac'").fetchone()
+    assert res == "\u20ac"
+
+
+def test_query_badenc(conn):
+    conn.encoding = "latin1"
+    cur = conn.cursor()
+    with pytest.raises(UnicodeEncodeError):
+        cur.execute("select '\u20ac'")
index ada260d8f3d9b8b7eeb5562eff60c5f297c02b78..ace5aacd491da01c545c8428adf38a8cead5a617 100644 (file)
@@ -1,5 +1,6 @@
 import pytest
 
+import psycopg3
 from psycopg3.adapt import Format
 
 
@@ -11,11 +12,9 @@ from psycopg3.adapt import Format
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 def test_adapt_1char(conn, format):
     cur = conn.cursor()
-    query = "select %s = chr(%%s::int)" % (
-        "%s" if format == Format.TEXT else "%b"
-    )
+    ph = "%s" if format == Format.TEXT else "%b"
     for i in range(1, 256):
-        cur.execute(query, (chr(i), i))
+        cur.execute("select %s = chr(%%s::int)" % ph, (chr(i), i))
         assert cur.fetchone()[0], chr(i)
 
 
@@ -29,6 +28,71 @@ def test_cast_1char(conn, format):
     assert cur.pgresult.fformat(0) == format
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("encoding", ["utf8", "latin9"])
+def test_adapt_enc(conn, format, encoding):
+    eur = "\u20ac"
+    cur = conn.cursor()
+    ph = "%s" if format == Format.TEXT else "%b"
+
+    conn.encoding = encoding
+    (res,) = cur.execute("select %s::bytea" % ph, (eur,)).fetchone()
+    assert res == eur.encode("utf8")
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_adapt_ascii(conn, format):
+    eur = "\u20ac"
+    cur = conn.cursor(binary=format == Format.BINARY)
+    ph = "%s" if format == Format.TEXT else "%b"
+
+    conn.encoding = "sql_ascii"
+    (res,) = cur.execute("select ascii(%s)" % ph, (eur,)).fetchone()
+    assert res == ord(eur)
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_adapt_badenc(conn, format):
+    eur = "\u20ac"
+    cur = conn.cursor()
+    ph = "%s" if format == Format.TEXT else "%b"
+
+    conn.encoding = "latin1"
+    with pytest.raises(UnicodeEncodeError):
+        cur.execute("select %s::bytea" % ph, (eur,))
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("encoding", ["utf8", "latin9"])
+def test_cast_enc(conn, format, encoding):
+    eur = "\u20ac"
+    cur = conn.cursor(binary=format == Format.BINARY)
+
+    conn.encoding = encoding
+    (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone()
+    assert res == eur
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_cast_badenc(conn, format):
+    eur = "\u20ac"
+    cur = conn.cursor(binary=format == Format.BINARY)
+
+    conn.encoding = "latin1"
+    with pytest.raises(psycopg3.DatabaseError):
+        cur.execute("select chr(%s::int)", (ord(eur),))
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_cast_ascii(conn, format):
+    eur = "\u20ac"
+    cur = conn.cursor(binary=format == Format.BINARY)
+
+    conn.encoding = "sql_ascii"
+    (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone()
+    assert res == eur.encode("utf8")
+
+
 #
 # tests with bytea
 #
@@ -37,11 +101,10 @@ def test_cast_1char(conn, format):
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 def test_adapt_1byte(conn, format):
     cur = conn.cursor()
-    query = "select %s = %%s::bytea" % (
-        "%s" if format == Format.TEXT else "%b"
-    )
+    ph = "%s" if format == Format.TEXT else "%b"
+    "select %s = %%s::bytea" % ph
     for i in range(0, 256):
-        cur.execute(query, (bytes([i]), fr"\x{i:02x}"))
+        cur.execute("select %s = %%s::bytea" % ph, (bytes([i]), fr"\x{i:02x}"))
         assert cur.fetchone()[0], i