]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Use conninfo encoding to encode errors on connection
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 2 Jan 2022 14:37:54 +0000 (15:37 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 2 Jan 2022 15:18:18 +0000 (16:18 +0100)
PostgreSQL returns connection errors in the encoding specified in
lc_messages (e.g. EUC-JP for a server with `lc_messages=ja_JP.EUC-JP`).
However, because we don't have a connection, the message was decoded in
the fallback utf-8.

If the client has specified a client_encoding, use it to decode the
error message. Be more lenient than usual to look up the encoding
(because it is not normalised by the server and we don't care about
performance as this only happens on error handling). As for any other
error messages, still use an error=replace policy to avoid exploding
reporting an error in the wrong encoding.

Close #194.

docs/news.rst
psycopg/psycopg/_encodings.py
psycopg/psycopg/generators.py
psycopg_c/psycopg_c/_psycopg/generators.pyx
tests/test_encodings.py

index bdadad964a2b43ac070a85faef33e2fa5354b04e..a364751b8f5510c3838a1644f3ba578ceb322698 100644 (file)
@@ -17,6 +17,13 @@ Psycopg 3.1 (unreleased)
 Current release
 ---------------
 
+Psycopg 3.0.8
+^^^^^^^^^^^^^
+
+- Decode connection errors in the ``client_encoding`` specified in the
+  connection string, if available (:ticket:`#194`).
+
+
 Psycopg 3.0.7
 ^^^^^^^^^^^^^
 
index 83f3e116c1229f71e5b0317c911f213297bd7954..41196667848b61da9f2e5fce6b7f90a2e8e61bc3 100644 (file)
@@ -62,14 +62,43 @@ _py_codecs = {
 py_codecs: Dict[Union[bytes, str], str] = {}
 py_codecs.update((k.encode(), v) for k, v in _py_codecs.items())
 
+# Add an alias without underscore, for lenient lookups
+py_codecs.update(
+    (k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k
+)
+
 pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
 
 
 def pgconn_encoding(pgconn: "PGconn") -> str:
+    """
+    Return the Python encoding name of a connection.
+
+    Default to utf8 if the connection has no encoding info.
+    """
     pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8"
     return pg2pyenc(pgenc)
 
 
+def conninfo_encoding(conninfo: str) -> str:
+    """
+    Return the Python encoding name passed in a conninfo string. Default to utf8.
+
+    Because the input is likely to come from the user and not normalised by the
+    server, be somewhat lenient (non-case-sensitive lookup, ignore noise chars).
+    """
+    from .conninfo import conninfo_to_dict
+
+    params = conninfo_to_dict(conninfo)
+    pgenc = params.get("client_encoding")
+    if pgenc:
+        pgenc = pgenc.replace("-", "").replace("_", "").upper().encode()
+        if pgenc in py_codecs:
+            return py_codecs[pgenc]
+
+    return "utf-8"
+
+
 def py2pgenc(name: str) -> bytes:
     """Convert a Python encoding name to PostgreSQL encoding name.
 
index 54a81ef4c25d25af45ecddd6bed9aa68482c5045..c988fd52e887e6379af23e94cba494cc12e79bd9 100644 (file)
@@ -24,7 +24,7 @@ from .pq import ConnStatus, PollingStatus, ExecStatus
 from .abc import PQGen, PQGenConn
 from .pq.abc import PGconn, PGresult
 from .waiting import Wait, Ready
-from ._encodings import py_codecs
+from ._encodings import py_codecs, conninfo_encoding
 
 logger = logging.getLogger(__name__)
 
@@ -37,8 +37,9 @@ def connect(conninfo: str) -> PQGenConn[PGconn]:
     conn = pq.PGconn.connect_start(conninfo.encode())
     while 1:
         if conn.status == ConnStatus.BAD:
+            encoding = conninfo_encoding(conninfo)
             raise e.OperationalError(
-                f"connection is bad: {pq.error_message(conn)}"
+                f"connection is bad: {pq.error_message(conn, encoding=encoding)}"
             )
 
         status = conn.connect_poll()
@@ -49,8 +50,9 @@ def connect(conninfo: str) -> PQGenConn[PGconn]:
         elif status == PollingStatus.WRITING:
             yield conn.socket, Wait.W
         elif status == PollingStatus.FAILED:
+            encoding = conninfo_encoding(conninfo)
             raise e.OperationalError(
-                f"connection failed: {pq.error_message(conn)}"
+                f"connection failed: {pq.error_message(conn, encoding=encoding)}"
             )
         else:
             raise e.InternalError(f"unexpected poll status: {status}")
index 4cb3478acf572195ebd2ba74e33f28bf5f9a0ec6..d5def64015344fc882f26bf9c2a63598e89267fe 100644 (file)
@@ -13,6 +13,7 @@ from psycopg import errors as e
 from psycopg.pq import abc, error_message
 from psycopg.abc import PQGen
 from psycopg.waiting import Wait, Ready
+from psycopg._encodings import conninfo_encoding
 
 cdef object WAIT_W = Wait.W
 cdef object WAIT_R = Wait.R
@@ -32,8 +33,9 @@ def connect(conninfo: str) -> PQGenConn[abc.PGconn]:
 
     while 1:
         if conn_status == libpq.CONNECTION_BAD:
+            encoding = conninfo_encoding(conninfo)
             raise e.OperationalError(
-                f"connection is bad: {error_message(conn)}"
+                f"connection is bad: {error_message(conn, encoding=encoding)}"
             )
 
         poll_status = libpq.PQconnectPoll(pgconn_ptr)
@@ -45,8 +47,9 @@ def connect(conninfo: str) -> PQGenConn[abc.PGconn]:
         elif poll_status == libpq.PGRES_POLLING_WRITING:
             yield (libpq.PQsocket(pgconn_ptr), WAIT_W)
         elif poll_status == libpq.PGRES_POLLING_FAILED:
+            encoding = conninfo_encoding(conninfo)
             raise e.OperationalError(
-                f"connection failed: {error_message(conn)}"
+                f"connection failed: {error_message(conn, encoding=encoding)}"
             )
         else:
             raise e.InternalError(f"unexpected poll status: {poll_status}")
index 4bbaf68bdcf6762ebe910fff435733ce823376d6..113f0e38bdbf00986338b8d28221ce16b1e1ef0f 100644 (file)
@@ -41,3 +41,17 @@ def test_pg2py(pyenc, pgenc):
 def test_pg2py_missing(pgenc):
     with pytest.raises(psycopg.NotSupportedError):
         encodings.pg2pyenc(pgenc.encode())
+
+
+@pytest.mark.parametrize(
+    "conninfo, pyenc",
+    [
+        ("", "utf-8"),
+        ("user=foo, dbname=bar", "utf-8"),
+        ("user=foo, dbname=bar, client_encoding=EUC_JP", "euc_jp"),
+        ("user=foo, dbname=bar, client_encoding=euc-jp", "euc_jp"),
+        ("user=foo, dbname=bar, client_encoding=WAT", "utf-8"),
+    ],
+)
+def test_conninfo_encoding(conninfo, pyenc):
+    assert encodings.conninfo_encoding(conninfo) == pyenc