]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Diagnostic objects can be pickled and error info survive pickling
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 20:46:52 +0000 (21:46 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 21:05:58 +0000 (22:05 +0100)
psycopg3/psycopg3/errors.py
tests/test_errors.py

index 96279fc17f1df66507d17d4c076e1cf40b887b93..edec693d9f0dcd4ccf6156e40da3fe9ba975f2c3 100644 (file)
@@ -19,6 +19,7 @@ DBAPI-defined Exceptions are defined in the following hierarchy::
 # Copyright (C) 2020 The Psycopg Team
 
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
+from typing import cast
 from psycopg3.pq.proto import PGresult
 from psycopg3.pq.enums import DiagnosticField
 
@@ -31,6 +32,9 @@ class Warning(Exception):
     """
 
 
+ErrorInfo = Union[None, PGresult, Dict[int, Optional[bytes]]]
+
+
 class Error(Exception):
     """
     Base exception for all the errors psycopg3 will raise.
@@ -39,24 +43,36 @@ class Error(Exception):
     def __init__(
         self,
         *args: Sequence[Any],
-        pgresult: Optional[PGresult] = None,
+        info: ErrorInfo = None,
         encoding: str = "utf-8"
     ):
         super().__init__(*args)
-        self.pgresult = pgresult
+        self._info = info
         self._encoding = encoding
 
     @property
     def diag(self) -> "Diagnostic":
-        return Diagnostic(self.pgresult, encoding=self._encoding)
+        return Diagnostic(self._info, encoding=self._encoding)
 
     def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         res = super().__reduce__()
         if isinstance(res, tuple) and len(res) >= 3:
-            res[2]['pgresult'] = None
+            res[2]["_info"] = self._info_to_dict(self._info)
 
         return res
 
+    @classmethod
+    def _info_to_dict(cls, info: ErrorInfo) -> ErrorInfo:
+        """
+        Convert a PGresult to a dictionary to make the info picklable.
+        """
+        # PGresult is a protocol, can't use isinstance
+        if hasattr(info, "error_field"):
+            info = cast(PGresult, info)
+            return {v: info.error_field(v) for v in DiagnosticField}
+        else:
+            return info
+
 
 class InterfaceError(Error):
     """
@@ -122,9 +138,9 @@ class NotSupportedError(DatabaseError):
 
 
 class Diagnostic:
-    def __init__(self, pgresult: Optional[PGresult], encoding: str = "utf-8"):
-        self.pgresult = pgresult
-        self.encoding = encoding
+    def __init__(self, info: ErrorInfo, encoding: str = "utf-8"):
+        self._info = info
+        self._encoding = encoding
 
     @property
     def severity(self) -> Optional[str]:
@@ -199,13 +215,24 @@ class Diagnostic:
         return self._error_message(DiagnosticField.SOURCE_FUNCTION)
 
     def _error_message(self, field: DiagnosticField) -> Optional[str]:
-        if self.pgresult is not None:
-            val = self.pgresult.error_field(field)
+        if self._info:
+            if isinstance(self._info, dict):
+                val = self._info.get(field)
+            else:
+                val = self._info.error_field(field)
+
             if val is not None:
-                return val.decode(self.encoding, "replace")
+                return val.decode(self._encoding, "replace")
 
         return None
 
+    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
+        res = super().__reduce__()
+        if isinstance(res, tuple) and len(res) >= 3:
+            res[2]["_info"] = Error._info_to_dict(self._info)
+
+        return res
+
 
 def lookup(sqlstate: str) -> Type[Error]:
     return _sqlcodes[sqlstate]
@@ -218,7 +245,7 @@ def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error:
     cls = _class_for_state(state.decode("ascii"))
     return cls(
         pq.error_message(result, encoding=encoding),
-        pgresult=result,
+        info=result,
         encoding=encoding,
     )
 
index b4b34750c5b313fd87088c62f1ffb78fc2d2ce4f..f6da644a4029cfc575befc4ce3db13a72e7fa322 100644 (file)
@@ -115,4 +115,19 @@ def test_error_pickle(conn):
 
     exc = pickle.loads(pickle.dumps(excinfo.value))
     assert isinstance(exc, e.UndefinedTable)
-    assert exc.pgresult is None
+    assert exc.diag.sqlstate == "42P01"
+
+
+def test_diag_pickle(conn):
+    cur = conn.cursor()
+    with pytest.raises(e.DatabaseError) as excinfo:
+        cur.execute("select 1 from wat")
+
+    diag1 = excinfo.value.diag
+    diag2 = pickle.loads(pickle.dumps(diag1))
+
+    assert isinstance(diag2, type(diag1))
+    for f in pq.DiagnosticField:
+        assert getattr(diag1, f.name.lower()) == getattr(diag2, f.name.lower())
+
+    assert diag2.sqlstate == "42P01"