From: Daniele Varrazzo Date: Wed, 28 Oct 2020 20:46:52 +0000 (+0100) Subject: Diagnostic objects can be pickled and error info survive pickling X-Git-Tag: 3.0.dev0~420 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b6c0b14c9b77f470cd7ceb4054053a3c56fa100a;p=thirdparty%2Fpsycopg.git Diagnostic objects can be pickled and error info survive pickling --- diff --git a/psycopg3/psycopg3/errors.py b/psycopg3/psycopg3/errors.py index 96279fc17..edec693d9 100644 --- a/psycopg3/psycopg3/errors.py +++ b/psycopg3/psycopg3/errors.py @@ -19,6 +19,7 @@ DBAPI-defined Exceptions are defined in the following hierarchy:: # Copyright (C) 2020 The Psycopg Team from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import cast from psycopg3.pq.proto import PGresult from psycopg3.pq.enums import DiagnosticField @@ -31,6 +32,9 @@ class Warning(Exception): """ +ErrorInfo = Union[None, PGresult, Dict[int, Optional[bytes]]] + + class Error(Exception): """ Base exception for all the errors psycopg3 will raise. @@ -39,24 +43,36 @@ class Error(Exception): def __init__( self, *args: Sequence[Any], - pgresult: Optional[PGresult] = None, + info: ErrorInfo = None, encoding: str = "utf-8" ): super().__init__(*args) - self.pgresult = pgresult + self._info = info self._encoding = encoding @property def diag(self) -> "Diagnostic": - return Diagnostic(self.pgresult, encoding=self._encoding) + return Diagnostic(self._info, encoding=self._encoding) def __reduce__(self) -> Union[str, Tuple[Any, ...]]: res = super().__reduce__() if isinstance(res, tuple) and len(res) >= 3: - res[2]['pgresult'] = None + res[2]["_info"] = self._info_to_dict(self._info) return res + @classmethod + def _info_to_dict(cls, info: ErrorInfo) -> ErrorInfo: + """ + Convert a PGresult to a dictionary to make the info picklable. + """ + # PGresult is a protocol, can't use isinstance + if hasattr(info, "error_field"): + info = cast(PGresult, info) + return {v: info.error_field(v) for v in DiagnosticField} + else: + return info + class InterfaceError(Error): """ @@ -122,9 +138,9 @@ class NotSupportedError(DatabaseError): class Diagnostic: - def __init__(self, pgresult: Optional[PGresult], encoding: str = "utf-8"): - self.pgresult = pgresult - self.encoding = encoding + def __init__(self, info: ErrorInfo, encoding: str = "utf-8"): + self._info = info + self._encoding = encoding @property def severity(self) -> Optional[str]: @@ -199,13 +215,24 @@ class Diagnostic: return self._error_message(DiagnosticField.SOURCE_FUNCTION) def _error_message(self, field: DiagnosticField) -> Optional[str]: - if self.pgresult is not None: - val = self.pgresult.error_field(field) + if self._info: + if isinstance(self._info, dict): + val = self._info.get(field) + else: + val = self._info.error_field(field) + if val is not None: - return val.decode(self.encoding, "replace") + return val.decode(self._encoding, "replace") return None + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + res = super().__reduce__() + if isinstance(res, tuple) and len(res) >= 3: + res[2]["_info"] = Error._info_to_dict(self._info) + + return res + def lookup(sqlstate: str) -> Type[Error]: return _sqlcodes[sqlstate] @@ -218,7 +245,7 @@ def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error: cls = _class_for_state(state.decode("ascii")) return cls( pq.error_message(result, encoding=encoding), - pgresult=result, + info=result, encoding=encoding, ) diff --git a/tests/test_errors.py b/tests/test_errors.py index b4b34750c..f6da644a4 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -115,4 +115,19 @@ def test_error_pickle(conn): exc = pickle.loads(pickle.dumps(excinfo.value)) assert isinstance(exc, e.UndefinedTable) - assert exc.pgresult is None + assert exc.diag.sqlstate == "42P01" + + +def test_diag_pickle(conn): + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select 1 from wat") + + diag1 = excinfo.value.diag + diag2 = pickle.loads(pickle.dumps(diag1)) + + assert isinstance(diag2, type(diag1)) + for f in pq.DiagnosticField: + assert getattr(diag1, f.name.lower()) == getattr(diag2, f.name.lower()) + + assert diag2.sqlstate == "42P01"