]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Raise error messages in the connection encoding
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 23 Jun 2020 11:17:10 +0000 (23:17 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Jun 2020 08:36:29 +0000 (20:36 +1200)
psycopg3/connection.py
psycopg3/errors.py
psycopg3/pq/misc.py
tests/pq/test_misc.py

index 4d331299952b7660327970afe316837d5cb1d27f..597a1d430f1716c0bdd05ae1b6eae1884e5940cc 100644 (file)
@@ -248,7 +248,8 @@ class Connection(BaseConnection):
         (pgres,) = self.wait(execute(self.pgconn))
         if pgres.status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
-                f"error on begin: {pq.error_message(pgres)}"
+                "error on begin:"
+                f" {pq.error_message(pgres, encoding=self.codec.name)}"
             )
 
     def commit(self) -> None:
@@ -268,7 +269,7 @@ class Connection(BaseConnection):
             if pgres.status != ExecStatus.COMMAND_OK:
                 raise e.OperationalError(
                     f"error on {command.decode('utf8')}:"
-                    f" {pq.error_message(pgres)}"
+                    f" {pq.error_message(pgres, encoding=self.codec.name)}"
                 )
 
     @classmethod
@@ -286,7 +287,7 @@ class Connection(BaseConnection):
             gen = execute(self.pgconn)
             (result,) = self.wait(gen)
             if result.status != ExecStatus.TUPLES_OK:
-                raise e.error_from_result(result)
+                raise e.error_from_result(result, encoding=self.codec.name)
 
     def notifies(self) -> Generator[Optional[Notify], bool, None]:
         decode = self.codec.decode
@@ -355,7 +356,8 @@ class AsyncConnection(BaseConnection):
         (pgres,) = await self.wait(execute(self.pgconn))
         if pgres.status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
-                f"error on begin: {pq.error_message(pgres)}"
+                "error on begin:"
+                f" {pq.error_message(pgres, encoding=self.codec.name)}"
             )
 
     async def commit(self) -> None:
@@ -375,7 +377,7 @@ class AsyncConnection(BaseConnection):
             if pgres.status != ExecStatus.COMMAND_OK:
                 raise e.OperationalError(
                     f"error on {command.decode('utf8')}:"
-                    f" {pq.error_message(pgres)}"
+                    f" {pq.error_message(pgres, encoding=self.codec.name)}"
                 )
 
     @classmethod
@@ -391,7 +393,7 @@ class AsyncConnection(BaseConnection):
             gen = execute(self.pgconn)
             (result,) = await self.wait(gen)
             if result.status != ExecStatus.TUPLES_OK:
-                raise e.error_from_result(result)
+                raise e.error_from_result(result, encoding=self.codec.name)
 
     async def notifies(self) -> AsyncGenerator[Optional[Notify], bool]:
         decode = self.codec.decode
index c16f7390b45a92fd347320ea7e308c12cc50443b..f69109a0bb59b0a9c241335a9741600f5ff7bb64 100644 (file)
@@ -209,7 +209,11 @@ def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error:
 
     state = result.error_field(DiagnosticField.SQLSTATE) or b""
     cls = _class_for_state(state.decode("ascii"))
-    return cls(pq.error_message(result), pgresult=result, encoding=encoding)
+    return cls(
+        pq.error_message(result, encoding=encoding),
+        pgresult=result,
+        encoding=encoding,
+    )
 
 
 def _class_for_state(sqlstate: str) -> Type[Error]:
index 73b77b0c439b871eb051ea5253fefad8a61d4e91..782e991cc979da1495cf0100597c01ad96db7b9d 100644 (file)
@@ -7,8 +7,9 @@ Various functionalities to make easier to work with the libpq.
 from typing import cast, NamedTuple, Optional, Union
 
 from ..errors import OperationalError
-from .enums import DiagnosticField
+from .enums import DiagnosticField, ConnStatus
 from .proto import PGconn, PGresult
+from .encodings import py_codecs
 
 
 class PQerror(OperationalError):
@@ -41,15 +42,19 @@ class PGresAttDesc(NamedTuple):
     atttypmod: int
 
 
-def error_message(obj: Union[PGconn, PGresult]) -> str:
+def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str:
     """
     Return an error message from a PGconn or PGresult.
 
-    The return value is a str (unlike pq data which is usually bytes).
+    The return value is a str (unlike pq data which is usually bytes): use
+    the connection encoding if available, otherwise the *encoding* parameter
+    as a fallback for decoding. Don't raise exception on decode errors.
+
     """
     bmsg: bytes
 
     if hasattr(obj, "error_field"):
+        # obj is a PGresult
         obj = cast(PGresult, obj)
 
         bmsg = obj.error_field(DiagnosticField.MESSAGE_PRIMARY) or b""
@@ -62,6 +67,11 @@ def error_message(obj: Union[PGconn, PGresult]) -> str:
 
     elif hasattr(obj, "error_message"):
         # obj is a PGconn
+        obj = cast(PGconn, obj)
+        if obj.status == ConnStatus.OK:
+            encoding = py_codecs.get(
+                obj.parameter_status(b"client_encoding"), "utf8"
+            )
         bmsg = obj.error_message
 
         # strip severity and whitespaces
@@ -74,9 +84,7 @@ def error_message(obj: Union[PGconn, PGresult]) -> str:
         )
 
     if bmsg:
-        msg = bmsg.decode(
-            "utf8", "replace"
-        )  # TODO: or in connection encoding?
+        msg = bmsg.decode(encoding, "replace")
     else:
         msg = "no details available"
 
index 891d97bc33edccbd9f9b7758111f503aff98b625..8545cfbac6f84e105e704305c7dcc5ebba277d13 100644 (file)
@@ -22,6 +22,26 @@ def test_error_message(pgconn):
     assert "NULL" in pq.error_message(pgconn)
 
 
+def test_error_message_encoding(pgconn):
+    res = pgconn.exec_(b"set client_encoding to latin9")
+    assert res.status == pq.ExecStatus.COMMAND_OK
+
+    res = pgconn.exec_('select 1 from "foo\u20acbar"'.encode("latin9"))
+    assert res.status == pq.ExecStatus.FATAL_ERROR
+
+    msg = pq.error_message(pgconn)
+    assert "foo\u20acbar" in msg
+
+    msg = pq.error_message(res)
+    assert "foo\ufffdbar" in msg
+
+    msg = pq.error_message(res, encoding="latin9")
+    assert "foo\u20acbar" in msg
+
+    msg = pq.error_message(res, encoding="ascii")
+    assert "foo\ufffdbar" in msg
+
+
 def test_make_empty_result(pgconn):
     pgconn.exec_(b"wat")
     res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)