# Copyright (C) 2020 The Psycopg Team
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
+from typing import cast
from psycopg3.pq.proto import PGresult
from psycopg3.pq.enums import DiagnosticField
"""
+ErrorInfo = Union[None, PGresult, Dict[int, Optional[bytes]]]
+
+
class Error(Exception):
"""
Base exception for all the errors psycopg3 will raise.
def __init__(
self,
*args: Sequence[Any],
- pgresult: Optional[PGresult] = None,
+ info: ErrorInfo = None,
encoding: str = "utf-8"
):
super().__init__(*args)
- self.pgresult = pgresult
+ self._info = info
self._encoding = encoding
@property
def diag(self) -> "Diagnostic":
- return Diagnostic(self.pgresult, encoding=self._encoding)
+ return Diagnostic(self._info, encoding=self._encoding)
def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
res = super().__reduce__()
if isinstance(res, tuple) and len(res) >= 3:
- res[2]['pgresult'] = None
+ res[2]["_info"] = self._info_to_dict(self._info)
return res
+ @classmethod
+ def _info_to_dict(cls, info: ErrorInfo) -> ErrorInfo:
+ """
+ Convert a PGresult to a dictionary to make the info picklable.
+ """
+ # PGresult is a protocol, can't use isinstance
+ if hasattr(info, "error_field"):
+ info = cast(PGresult, info)
+ return {v: info.error_field(v) for v in DiagnosticField}
+ else:
+ return info
+
class InterfaceError(Error):
"""
class Diagnostic:
- def __init__(self, pgresult: Optional[PGresult], encoding: str = "utf-8"):
- self.pgresult = pgresult
- self.encoding = encoding
+ def __init__(self, info: ErrorInfo, encoding: str = "utf-8"):
+ self._info = info
+ self._encoding = encoding
@property
def severity(self) -> Optional[str]:
return self._error_message(DiagnosticField.SOURCE_FUNCTION)
def _error_message(self, field: DiagnosticField) -> Optional[str]:
- if self.pgresult is not None:
- val = self.pgresult.error_field(field)
+ if self._info:
+ if isinstance(self._info, dict):
+ val = self._info.get(field)
+ else:
+ val = self._info.error_field(field)
+
if val is not None:
- return val.decode(self.encoding, "replace")
+ return val.decode(self._encoding, "replace")
return None
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
+ res = super().__reduce__()
+ if isinstance(res, tuple) and len(res) >= 3:
+ res[2]["_info"] = Error._info_to_dict(self._info)
+
+ return res
+
def lookup(sqlstate: str) -> Type[Error]:
return _sqlcodes[sqlstate]
cls = _class_for_state(state.decode("ascii"))
return cls(
pq.error_message(result, encoding=encoding),
- pgresult=result,
+ info=result,
encoding=encoding,
)
exc = pickle.loads(pickle.dumps(excinfo.value))
assert isinstance(exc, e.UndefinedTable)
- assert exc.pgresult is None
+ assert exc.diag.sqlstate == "42P01"
+
+
+def test_diag_pickle(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute("select 1 from wat")
+
+ diag1 = excinfo.value.diag
+ diag2 = pickle.loads(pickle.dumps(diag1))
+
+ assert isinstance(diag2, type(diag1))
+ for f in pq.DiagnosticField:
+ assert getattr(diag1, f.name.lower()) == getattr(diag2, f.name.lower())
+
+ assert diag2.sqlstate == "42P01"