]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added Diagnostic object
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 May 2020 04:51:59 +0000 (16:51 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 May 2020 04:51:59 +0000 (16:51 +1200)
Use Diagnostic as exception .diag attribute and pgresult wrapper in
notices.

psycopg3/connection.py
psycopg3/cursor.py
psycopg3/errors.py
tests/pq/test_pgconn.py
tests/test_async_connection.py
tests/test_connection.py
tests/test_errors.py [new file with mode: 0644]

index 63aceb5dc9960dd781d7abe8193a980e092e0639..27b1473629f26f12b9a3f5c89f2a42745096ae54 100644 (file)
@@ -39,7 +39,7 @@ else:
     connect = generators.connect
     execute = generators.execute
 
-NoticeCallback = Callable[[pq.proto.PGresult], None]
+NoticeCallback = Callable[[e.Diagnostic], None]
 
 
 class BaseConnection:
@@ -151,11 +151,13 @@ class BaseConnection:
         wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult
     ) -> None:
         self = wself()
-        if self is None:
+        if self is None or not self._notice_callback:
             return
+
+        diag = e.Diagnostic(res, self.codec.name)
         for cb in self._notice_callbacks:
             try:
-                cb(res)
+                cb(diag)
             except Exception as ex:
                 package_logger.exception(
                     "error processing notice callback '%s': %s", cb, ex
index 867515e39b71183fc5b9ecb5e6e91139146cba67..3a81e5f62214adac8855943f2160043e1e894400 100644 (file)
@@ -4,7 +4,6 @@ psycopg3 cursor objects
 
 # Copyright (C) 2020 The Psycopg Team
 
-import codecs
 from operator import attrgetter
 from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING
 
@@ -31,12 +30,10 @@ else:
 
 
 class Column(Sequence[Any]):
-    def __init__(
-        self, pgresult: pq.proto.PGresult, index: int, codec: codecs.CodecInfo
-    ):
+    def __init__(self, pgresult: pq.proto.PGresult, index: int, encoding: str):
         self._pgresult = pgresult
         self._index = index
-        self._codec = codec
+        self._encoding = encoding
 
     _attrs = tuple(
         map(
@@ -57,7 +54,7 @@ class Column(Sequence[Any]):
     def name(self) -> str:
         rv = self._pgresult.fname(self._index)
         if rv is not None:
-            return self._codec.decode(rv)[0]
+            return rv.decode(self._encoding)
         else:
             raise e.InterfaceError(
                 f"no name available for column {self._index}"
@@ -117,7 +114,8 @@ class BaseCursor:
         if res is None or res.status != self.ExecStatus.TUPLES_OK:
             return None
         return [
-            Column(res, i, self.connection.codec) for i in range(res.nfields)
+            Column(res, i, self.connection.codec.name)
+            for i in range(res.nfields)
         ]
 
     @property
@@ -196,7 +194,9 @@ class BaseCursor:
             return
 
         if results[-1].status == S.FATAL_ERROR:
-            raise e.error_from_result(results[-1])
+            raise e.error_from_result(
+                results[-1], encoding=self.connection.codec.name
+            )
 
         elif badstats & {S.COPY_IN, S.COPY_OUT, S.COPY_BOTH}:
             raise e.ProgrammingError(
@@ -283,7 +283,9 @@ class Cursor(BaseCursor):
                     gen = execute(self.connection.pgconn)
                     (result,) = self.connection.wait(gen)
                     if result.status == self.ExecStatus.FATAL_ERROR:
-                        raise e.error_from_result(result)
+                        raise e.error_from_result(
+                            result, encoding=self.connection.codec.name
+                        )
                 else:
                     pgq.dump(vars)
 
@@ -373,7 +375,9 @@ class AsyncCursor(BaseCursor):
                     gen = execute(self.connection.pgconn)
                     (result,) = await self.connection.wait(gen)
                     if result.status == self.ExecStatus.FATAL_ERROR:
-                        raise e.error_from_result(result)
+                        raise e.error_from_result(
+                            result, encoding=self.connection.codec.name
+                        )
                 else:
                     pgq.dump(vars)
 
index cd95e2957931e56678f4514ea931ad735a8127a6..3f4d8c0fe78f0666c2ec9d68e87076c10f31d915 100644 (file)
@@ -20,6 +20,7 @@ DBAPI-defined Exceptions are defined in the following hierarchy::
 
 from typing import Any, Optional, Sequence, Type
 from psycopg3.pq.proto import PGresult
+from psycopg3.pq.enums import DiagnosticField
 
 
 class Warning(Exception):
@@ -36,10 +37,18 @@ class Error(Exception):
     """
 
     def __init__(
-        self, *args: Sequence[Any], pgresult: Optional[PGresult] = None
+        self,
+        *args: Sequence[Any],
+        pgresult: Optional[PGresult] = None,
+        encoding: str = "utf-8"
     ):
         super().__init__(*args)
         self.pgresult = pgresult
+        self._encoding = encoding
+
+    @property
+    def diag(self) -> "Diagnostic":
+        return Diagnostic(self.pgresult, encoding=self._encoding)
 
 
 class InterfaceError(Error):
@@ -105,14 +114,100 @@ class NotSupportedError(DatabaseError):
     """
 
 
+class Diagnostic:
+    def __init__(self, pgresult: Optional[PGresult], encoding: str = "utf-8"):
+        self.pgresult = pgresult
+        self.encoding = encoding
+
+    @property
+    def severity(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.SEVERITY)
+
+    @property
+    def severity_nonlocalized(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.SEVERITY_NONLOCALIZED)
+
+    @property
+    def sqlstate(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.SQLSTATE)
+
+    @property
+    def message_primary(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.MESSAGE_PRIMARY)
+
+    @property
+    def message_detail(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.MESSAGE_DETAIL)
+
+    @property
+    def message_hint(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.MESSAGE_HINT)
+
+    @property
+    def statement_position(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.STATEMENT_POSITION)
+
+    @property
+    def internal_position(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.INTERNAL_POSITION)
+
+    @property
+    def internal_query(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.INTERNAL_QUERY)
+
+    @property
+    def context(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.CONTEXT)
+
+    @property
+    def schema_name(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.SCHEMA_NAME)
+
+    @property
+    def table_name(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.TABLE_NAME)
+
+    @property
+    def column_name(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.COLUMN_NAME)
+
+    @property
+    def datatype_name(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.DATATYPE_NAME)
+
+    @property
+    def constraint_name(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.CONSTRAINT_NAME)
+
+    @property
+    def source_file(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.SOURCE_FILE)
+
+    @property
+    def source_line(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.SOURCE_LINE)
+
+    @property
+    def source_function(self) -> Optional[str]:
+        return self._error_message(DiagnosticField.SOURCE_FUNCTION)
+
+    def _error_message(self, field: DiagnosticField) -> Optional[str]:
+        if self.pgresult is not None:
+            val = self.pgresult.error_field(field)
+            if val is not None:
+                return val.decode(self.encoding, "replace")
+
+        return None
+
+
 def class_for_state(sqlstate: bytes) -> Type[Error]:
     # TODO: stub
     return DatabaseError
 
 
-def error_from_result(result: PGresult) -> Error:
+def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error:
     from psycopg3 import pq
 
-    state = result.error_field(pq.DiagnosticField.SQLSTATE) or b""
+    state = result.error_field(DiagnosticField.SQLSTATE) or b""
     cls = class_for_state(state)
-    return cls(pq.error_message(result))
+    return cls(pq.error_message(result), pgresult=result, encoding=encoding)
index 669939b2a66bd421aa8b2d91114ea87730eb6868..27034f0aaf340b84dc22f682a418beb146fd7973 100644 (file)
@@ -333,13 +333,7 @@ def test_make_empty_result(pq, pgconn):
 
 def test_notice_nohandler(pq, pgconn):
     res = pgconn.exec_(
-        b"""
-do $$
-begin
-    raise notice 'hello notice';
-end
-$$ language plpgsql
-    """
+        b"do $$begin raise notice 'hello notice'; end$$ language plpgsql"
     )
     assert res.status == pq.ExecStatus.COMMAND_OK
 
@@ -353,13 +347,7 @@ def test_notice(pq, pgconn):
 
     pgconn.notice_callback = callback
     res = pgconn.exec_(
-        b"""
-do $$
-begin
-    raise notice 'hello notice';
-end
-$$ language plpgsql
-    """
+        b"do $$begin raise notice 'hello notice'; end$$ language plpgsql"
     )
 
     assert res.status == pq.ExecStatus.COMMAND_OK
@@ -374,13 +362,7 @@ def test_notice_error(pq, pgconn, caplog):
 
     pgconn.notice_callback = callback
     res = pgconn.exec_(
-        b"""
-do $$
-begin
-    raise notice 'hello notice';
-end
-$$ language plpgsql
-    """
+        b"do $$begin raise notice 'hello notice'; end$$ language plpgsql"
     )
 
     assert res.status == pq.ExecStatus.COMMAND_OK
index 31c3044062beb2139825b0dd3e9788f08928c6c7..8827484ab2ae807dd5114ea1b80ede38d67a61e5 100644 (file)
@@ -223,38 +223,27 @@ def test_notice_callbacks(aconn, loop, caplog):
     messages = []
     severities = []
 
-    def cb1(res):
-        messages.append(
-            res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY)
-        )
+    def cb1(diag):
+        messages.append(diag.message_primary)
 
     def cb2(res):
         raise Exception("hello from cb2")
 
-    def cb3(res):
-        severities.append(
-            res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED)
-        )
-
     aconn.add_notice_callback(cb1)
     aconn.add_notice_callback(cb2)
     aconn.add_notice_callback("the wrong thing")
-    aconn.add_notice_callback(cb3)
+    aconn.add_notice_callback(
+        lambda diag: severities.append(diag.severity_nonlocalized)
+    )
 
     cur = aconn.cursor()
     loop.run_until_complete(
         cur.execute(
-            """
-do $$
-begin
-    raise notice 'hello notice';
-end
-$$ language plpgsql
-    """
+            "do $$begin raise notice 'hello notice'; end$$ language plpgsql"
         )
     )
-    assert messages == [b"hello notice"]
-    assert severities == [b"NOTICE"]
+    assert messages == ["hello notice"]
+    assert severities == ["NOTICE"]
 
     assert len(caplog.records) == 2
     rec = caplog.records[0]
@@ -268,18 +257,12 @@ $$ language plpgsql
     aconn.remove_notice_callback("the wrong thing")
     loop.run_until_complete(
         cur.execute(
-            """
-do $$
-begin
-    raise warning 'hello warning';
-end
-$$ language plpgsql
-    """
+            "do $$begin raise warning 'hello warning'; end$$ language plpgsql"
         )
     )
     assert len(caplog.records) == 3
-    assert messages == [b"hello notice"]
-    assert severities == [b"NOTICE", b"WARNING"]
+    assert messages == ["hello notice"]
+    assert severities == ["NOTICE", "WARNING"]
 
     with pytest.raises(ValueError):
         aconn.remove_notice_callback(cb1)
index c1dce1bbc296dd263a1c36d5517cafe922e62414..830f941454ac056c05b5742da542508eea7a08ae 100644 (file)
@@ -213,36 +213,25 @@ def test_notice_callbacks(conn, caplog):
     messages = []
     severities = []
 
-    def cb1(res):
-        messages.append(
-            res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY)
-        )
+    def cb1(diag):
+        messages.append(diag.message_primary)
 
     def cb2(res):
         raise Exception("hello from cb2")
 
-    def cb3(res):
-        severities.append(
-            res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED)
-        )
-
     conn.add_notice_callback(cb1)
     conn.add_notice_callback(cb2)
     conn.add_notice_callback("the wrong thing")
-    conn.add_notice_callback(cb3)
+    conn.add_notice_callback(
+        lambda diag: severities.append(diag.severity_nonlocalized)
+    )
 
     cur = conn.cursor()
     cur.execute(
-        """
-do $$
-begin
-    raise notice 'hello notice';
-end
-$$ language plpgsql
-    """
+        "do $$begin raise notice 'hello notice'; end$$ language plpgsql"
     )
-    assert messages == [b"hello notice"]
-    assert severities == [b"NOTICE"]
+    assert messages == ["hello notice"]
+    assert severities == ["NOTICE"]
 
     assert len(caplog.records) == 2
     rec = caplog.records[0]
@@ -255,17 +244,11 @@ $$ language plpgsql
     conn.remove_notice_callback(cb1)
     conn.remove_notice_callback("the wrong thing")
     cur.execute(
-        """
-do $$
-begin
-    raise warning 'hello warning';
-end
-$$ language plpgsql
-    """
+        "do $$begin raise warning 'hello warning'; end$$ language plpgsql"
     )
     assert len(caplog.records) == 3
-    assert messages == [b"hello notice"]
-    assert severities == [b"NOTICE", b"WARNING"]
+    assert messages == ["hello notice"]
+    assert severities == ["NOTICE", "WARNING"]
 
     with pytest.raises(ValueError):
         conn.remove_notice_callback(cb1)
diff --git a/tests/test_errors.py b/tests/test_errors.py
new file mode 100644 (file)
index 0000000..dc18e79
--- /dev/null
@@ -0,0 +1,73 @@
+import pytest
+from psycopg3 import errors as e
+
+eur = "\u20ac"
+
+
+def test_error_diag(conn):
+    cur = conn.cursor()
+    with pytest.raises(e.DatabaseError) as excinfo:
+        cur.execute("select 1 from wat")
+
+    exc = excinfo.value
+    diag = exc.diag
+    assert diag.sqlstate == "42P01"
+    assert diag.severity_nonlocalized == "ERROR"
+
+
+def test_diag_all_attrs(pgconn, pq):
+    res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR)
+    diag = e.Diagnostic(res)
+    for d in pq.DiagnosticField:
+        val = getattr(diag, d.name.lower())
+        assert val is None or isinstance(val, str)
+
+
+def test_diag_right_attr(pgconn, pq, monkeypatch):
+    res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR)
+    diag = e.Diagnostic(res)
+
+    checked = []
+
+    def check_val(self, v):
+        nonlocal to_check
+        assert to_check == v
+        checked.append(v)
+        return None
+
+    monkeypatch.setattr(e.Diagnostic, "_error_message", check_val)
+
+    for to_check in pq.DiagnosticField:
+        getattr(diag, to_check.name.lower())
+
+    assert len(checked) == len(pq.DiagnosticField)
+
+
+@pytest.mark.parametrize("enc", ["utf8", "latin9"])
+def test_diag_encoding(conn, enc):
+    msgs = []
+    conn.add_notice_callback(lambda diag: msgs.append(diag.message_primary))
+    conn.set_client_encoding(enc)
+    cur = conn.cursor()
+    cur.execute(
+        "do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql"
+    )
+    assert msgs == [f"hello {eur}"]
+
+
+@pytest.mark.parametrize("enc", ["utf8", "latin9"])
+def test_error_encoding(conn, enc):
+    conn.set_client_encoding(enc)
+    cur = conn.cursor()
+    with pytest.raises(e.DatabaseError) as excinfo:
+        cur.execute(
+            """
+            do $$begin
+                execute format('insert into "%s" values (1)', chr(8364));
+            end$$ language plpgsql;
+            """
+        )
+
+    diag = excinfo.value.diag
+    assert f'"{eur}"' in diag.message_primary
+    assert diag.sqlstate == "42P01"