]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add `Error.pgresult` property
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 12 Mar 2022 22:57:55 +0000 (22:57 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 13 Mar 2022 00:32:54 +0000 (00:32 +0000)
It might be useful for the same reason `pgconn` is.

For symmetry convert `pgconn` to a read-only attribute too.

psycopg/psycopg/_compat.py
psycopg/psycopg/errors.py
tests/test_errors.py

index 36bc892fd24c4a4c160ba1ee5cd1ec6936fc18c6..a09dd5cecc2eb3e429345b7169b75c1bf7ee3128 100644 (file)
@@ -34,10 +34,16 @@ else:
     from backports.zoneinfo import ZoneInfo
     from typing import Counter, Deque
 
+if sys.version_info >= (3, 10):
+    from typing import TypeGuard
+else:
+    from typing_extensions import TypeGuard
+
 __all__ = [
     "Counter",
     "Deque",
     "Protocol",
+    "TypeGuard",
     "ZoneInfo",
     "create_task",
 ]
index e5a94913adc96903ba963278f29b59ff0afb5041..e77fa8c39341a3025952a1cfe82d421be7707688 100644 (file)
@@ -19,10 +19,10 @@ DBAPI-defined Exceptions are defined in the following hierarchy::
 # Copyright (C) 2020 The Psycopg Team
 
 from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
-from typing import cast
 
-from psycopg.pq.abc import PGconn, PGresult
-from psycopg.pq._enums import DiagnosticField
+from .pq.abc import PGconn, PGresult
+from .pq._enums import DiagnosticField
+from ._compat import TypeGuard
 
 ErrorInfo = Union[None, PGresult, Dict[int, Optional[bytes]]]
 
@@ -52,7 +52,6 @@ class Error(Exception):
     __module__ = "psycopg"
 
     sqlstate: Optional[str] = None
-    pgconn: Optional[PGconn] = None
 
     def __init__(
         self,
@@ -64,12 +63,20 @@ class Error(Exception):
         super().__init__(*args)
         self._info = info
         self._encoding = encoding
-        self.pgconn = pgconn
+        self._pgconn = pgconn
 
         # Handle sqlstate codes for which we don't have a class.
         if not self.sqlstate and info:
             self.sqlstate = self.diag.sqlstate
 
+    @property
+    def pgconn(self) -> Optional[PGconn]:
+        return self._pgconn if self._pgconn else None
+
+    @property
+    def pgresult(self) -> Optional[PGresult]:
+        return self._info if _is_pgresult(self._info) else None
+
     @property
     def diag(self) -> "Diagnostic":
         """
@@ -80,24 +87,12 @@ class Error(Exception):
     def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         res = super().__reduce__()
         if isinstance(res, tuple) and len(res) >= 3:
-            res[2]["_info"] = self._info_to_dict(self._info)
             # To make the exception picklable
-            res[2]["pgconn"] = None
+            res[2]["_info"] = _info_to_dict(self._info)
+            res[2]["_pgconn"] = None
 
         return res
 
-    @classmethod
-    def _info_to_dict(cls, info: ErrorInfo) -> ErrorInfo:
-        """
-        Convert a PGresult to a dictionary to make the info picklable.
-        """
-        # PGresult is a protocol, can't use isinstance
-        if hasattr(info, "error_field"):
-            info = cast(PGresult, info)
-            return {v: info.error_field(v) for v in DiagnosticField}
-        else:
-            return info
-
 
 class InterfaceError(Error):
     """
@@ -285,11 +280,22 @@ class Diagnostic:
     def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
         res = super().__reduce__()
         if isinstance(res, tuple) and len(res) >= 3:
-            res[2]["_info"] = Error._info_to_dict(self._info)
+            res[2]["_info"] = _info_to_dict(self._info)
 
         return res
 
 
+def _info_to_dict(info: ErrorInfo) -> ErrorInfo:
+    """
+    Convert a PGresult to a dictionary to make the info picklable.
+    """
+    # PGresult is a protocol, can't use isinstance
+    if _is_pgresult(info):
+        return {v: info.error_field(v) for v in DiagnosticField}
+    else:
+        return info
+
+
 def lookup(sqlstate: str) -> Type[Error]:
     """Lookup an error code or `constant name`__ and return its exception class.
 
@@ -313,6 +319,12 @@ def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error:
     )
 
 
+def _is_pgresult(info: ErrorInfo) -> TypeGuard[PGresult]:
+    """Return True if an ErrorInfo is a PGresult instance."""
+    # PGresult is a protocol, can't use isinstance
+    return hasattr(info, "error_field")
+
+
 def _class_for_state(sqlstate: str) -> Type[Error]:
     try:
         return lookup(sqlstate)
index c54a0e43f6e8d768375c468afbc8daf0b58afb61..553aac3b01c10614ee24228ad41804bfea0bd7f6 100644 (file)
@@ -277,4 +277,22 @@ def test_pgconn_error_pickle():
         psycopg.connect("dbname=nosuchdb")
 
     exc = pickle.loads(pickle.dumps(excinfo.value))
-    assert not exc.pgconn
+    assert exc.pgconn is None
+
+
+def test_pgresult(conn):
+    with pytest.raises(e.DatabaseError) as excinfo:
+        conn.execute("select 1 from wat")
+
+    exc = excinfo.value
+    assert exc.pgresult
+    assert exc.pgresult.error_field(pq.DiagnosticField.SQLSTATE) == b"42P01"
+
+
+def test_pgresult_pickle(conn):
+    with pytest.raises(e.DatabaseError) as excinfo:
+        conn.execute("select 1 from wat")
+
+    exc = pickle.loads(pickle.dumps(excinfo.value))
+    assert exc.pgresult is None
+    assert exc.diag.sqlstate == "42P01"