From 8857076cee093620f63b498d4d24df0ad196cf61 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Fri, 22 May 2020 16:51:59 +1200 Subject: [PATCH] Added Diagnostic object Use Diagnostic as exception .diag attribute and pgresult wrapper in notices. --- psycopg3/connection.py | 8 ++- psycopg3/cursor.py | 24 ++++---- psycopg3/errors.py | 103 +++++++++++++++++++++++++++++++-- tests/pq/test_pgconn.py | 24 +------- tests/test_async_connection.py | 39 ++++--------- tests/test_connection.py | 39 ++++--------- tests/test_errors.py | 73 +++++++++++++++++++++++ 7 files changed, 216 insertions(+), 94 deletions(-) create mode 100644 tests/test_errors.py diff --git a/psycopg3/connection.py b/psycopg3/connection.py index 63aceb5dc..27b147362 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -39,7 +39,7 @@ else: connect = generators.connect execute = generators.execute -NoticeCallback = Callable[[pq.proto.PGresult], None] +NoticeCallback = Callable[[e.Diagnostic], None] class BaseConnection: @@ -151,11 +151,13 @@ class BaseConnection: wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult ) -> None: self = wself() - if self is None: + if self is None or not self._notice_callback: return + + diag = e.Diagnostic(res, self.codec.name) for cb in self._notice_callbacks: try: - cb(res) + cb(diag) except Exception as ex: package_logger.exception( "error processing notice callback '%s': %s", cb, ex diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index 867515e39..3a81e5f62 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -4,7 +4,6 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team -import codecs from operator import attrgetter from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING @@ -31,12 +30,10 @@ else: class Column(Sequence[Any]): - def __init__( - self, pgresult: pq.proto.PGresult, index: int, codec: codecs.CodecInfo - ): + def __init__(self, pgresult: pq.proto.PGresult, index: int, encoding: str): self._pgresult = pgresult self._index = index - self._codec = codec + self._encoding = encoding _attrs = tuple( map( @@ -57,7 +54,7 @@ class Column(Sequence[Any]): def name(self) -> str: rv = self._pgresult.fname(self._index) if rv is not None: - return self._codec.decode(rv)[0] + return rv.decode(self._encoding) else: raise e.InterfaceError( f"no name available for column {self._index}" @@ -117,7 +114,8 @@ class BaseCursor: if res is None or res.status != self.ExecStatus.TUPLES_OK: return None return [ - Column(res, i, self.connection.codec) for i in range(res.nfields) + Column(res, i, self.connection.codec.name) + for i in range(res.nfields) ] @property @@ -196,7 +194,9 @@ class BaseCursor: return if results[-1].status == S.FATAL_ERROR: - raise e.error_from_result(results[-1]) + raise e.error_from_result( + results[-1], encoding=self.connection.codec.name + ) elif badstats & {S.COPY_IN, S.COPY_OUT, S.COPY_BOTH}: raise e.ProgrammingError( @@ -283,7 +283,9 @@ class Cursor(BaseCursor): gen = execute(self.connection.pgconn) (result,) = self.connection.wait(gen) if result.status == self.ExecStatus.FATAL_ERROR: - raise e.error_from_result(result) + raise e.error_from_result( + result, encoding=self.connection.codec.name + ) else: pgq.dump(vars) @@ -373,7 +375,9 @@ class AsyncCursor(BaseCursor): gen = execute(self.connection.pgconn) (result,) = await self.connection.wait(gen) if result.status == self.ExecStatus.FATAL_ERROR: - raise e.error_from_result(result) + raise e.error_from_result( + result, encoding=self.connection.codec.name + ) else: pgq.dump(vars) diff --git a/psycopg3/errors.py b/psycopg3/errors.py index cd95e2957..3f4d8c0fe 100644 --- a/psycopg3/errors.py +++ b/psycopg3/errors.py @@ -20,6 +20,7 @@ DBAPI-defined Exceptions are defined in the following hierarchy:: from typing import Any, Optional, Sequence, Type from psycopg3.pq.proto import PGresult +from psycopg3.pq.enums import DiagnosticField class Warning(Exception): @@ -36,10 +37,18 @@ class Error(Exception): """ def __init__( - self, *args: Sequence[Any], pgresult: Optional[PGresult] = None + self, + *args: Sequence[Any], + pgresult: Optional[PGresult] = None, + encoding: str = "utf-8" ): super().__init__(*args) self.pgresult = pgresult + self._encoding = encoding + + @property + def diag(self) -> "Diagnostic": + return Diagnostic(self.pgresult, encoding=self._encoding) class InterfaceError(Error): @@ -105,14 +114,100 @@ class NotSupportedError(DatabaseError): """ +class Diagnostic: + def __init__(self, pgresult: Optional[PGresult], encoding: str = "utf-8"): + self.pgresult = pgresult + self.encoding = encoding + + @property + def severity(self) -> Optional[str]: + return self._error_message(DiagnosticField.SEVERITY) + + @property + def severity_nonlocalized(self) -> Optional[str]: + return self._error_message(DiagnosticField.SEVERITY_NONLOCALIZED) + + @property + def sqlstate(self) -> Optional[str]: + return self._error_message(DiagnosticField.SQLSTATE) + + @property + def message_primary(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_PRIMARY) + + @property + def message_detail(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_DETAIL) + + @property + def message_hint(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_HINT) + + @property + def statement_position(self) -> Optional[str]: + return self._error_message(DiagnosticField.STATEMENT_POSITION) + + @property + def internal_position(self) -> Optional[str]: + return self._error_message(DiagnosticField.INTERNAL_POSITION) + + @property + def internal_query(self) -> Optional[str]: + return self._error_message(DiagnosticField.INTERNAL_QUERY) + + @property + def context(self) -> Optional[str]: + return self._error_message(DiagnosticField.CONTEXT) + + @property + def schema_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.SCHEMA_NAME) + + @property + def table_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.TABLE_NAME) + + @property + def column_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.COLUMN_NAME) + + @property + def datatype_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.DATATYPE_NAME) + + @property + def constraint_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.CONSTRAINT_NAME) + + @property + def source_file(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_FILE) + + @property + def source_line(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_LINE) + + @property + def source_function(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_FUNCTION) + + def _error_message(self, field: DiagnosticField) -> Optional[str]: + if self.pgresult is not None: + val = self.pgresult.error_field(field) + if val is not None: + return val.decode(self.encoding, "replace") + + return None + + def class_for_state(sqlstate: bytes) -> Type[Error]: # TODO: stub return DatabaseError -def error_from_result(result: PGresult) -> Error: +def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error: from psycopg3 import pq - state = result.error_field(pq.DiagnosticField.SQLSTATE) or b"" + state = result.error_field(DiagnosticField.SQLSTATE) or b"" cls = class_for_state(state) - return cls(pq.error_message(result)) + return cls(pq.error_message(result), pgresult=result, encoding=encoding) diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py index 669939b2a..27034f0aa 100644 --- a/tests/pq/test_pgconn.py +++ b/tests/pq/test_pgconn.py @@ -333,13 +333,7 @@ def test_make_empty_result(pq, pgconn): def test_notice_nohandler(pq, pgconn): res = pgconn.exec_( - b""" -do $$ -begin - raise notice 'hello notice'; -end -$$ language plpgsql - """ + b"do $$begin raise notice 'hello notice'; end$$ language plpgsql" ) assert res.status == pq.ExecStatus.COMMAND_OK @@ -353,13 +347,7 @@ def test_notice(pq, pgconn): pgconn.notice_callback = callback res = pgconn.exec_( - b""" -do $$ -begin - raise notice 'hello notice'; -end -$$ language plpgsql - """ + b"do $$begin raise notice 'hello notice'; end$$ language plpgsql" ) assert res.status == pq.ExecStatus.COMMAND_OK @@ -374,13 +362,7 @@ def test_notice_error(pq, pgconn, caplog): pgconn.notice_callback = callback res = pgconn.exec_( - b""" -do $$ -begin - raise notice 'hello notice'; -end -$$ language plpgsql - """ + b"do $$begin raise notice 'hello notice'; end$$ language plpgsql" ) assert res.status == pq.ExecStatus.COMMAND_OK diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index 31c304406..8827484ab 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -223,38 +223,27 @@ def test_notice_callbacks(aconn, loop, caplog): messages = [] severities = [] - def cb1(res): - messages.append( - res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY) - ) + def cb1(diag): + messages.append(diag.message_primary) def cb2(res): raise Exception("hello from cb2") - def cb3(res): - severities.append( - res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED) - ) - aconn.add_notice_callback(cb1) aconn.add_notice_callback(cb2) aconn.add_notice_callback("the wrong thing") - aconn.add_notice_callback(cb3) + aconn.add_notice_callback( + lambda diag: severities.append(diag.severity_nonlocalized) + ) cur = aconn.cursor() loop.run_until_complete( cur.execute( - """ -do $$ -begin - raise notice 'hello notice'; -end -$$ language plpgsql - """ + "do $$begin raise notice 'hello notice'; end$$ language plpgsql" ) ) - assert messages == [b"hello notice"] - assert severities == [b"NOTICE"] + assert messages == ["hello notice"] + assert severities == ["NOTICE"] assert len(caplog.records) == 2 rec = caplog.records[0] @@ -268,18 +257,12 @@ $$ language plpgsql aconn.remove_notice_callback("the wrong thing") loop.run_until_complete( cur.execute( - """ -do $$ -begin - raise warning 'hello warning'; -end -$$ language plpgsql - """ + "do $$begin raise warning 'hello warning'; end$$ language plpgsql" ) ) assert len(caplog.records) == 3 - assert messages == [b"hello notice"] - assert severities == [b"NOTICE", b"WARNING"] + assert messages == ["hello notice"] + assert severities == ["NOTICE", "WARNING"] with pytest.raises(ValueError): aconn.remove_notice_callback(cb1) diff --git a/tests/test_connection.py b/tests/test_connection.py index c1dce1bbc..830f94145 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -213,36 +213,25 @@ def test_notice_callbacks(conn, caplog): messages = [] severities = [] - def cb1(res): - messages.append( - res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY) - ) + def cb1(diag): + messages.append(diag.message_primary) def cb2(res): raise Exception("hello from cb2") - def cb3(res): - severities.append( - res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED) - ) - conn.add_notice_callback(cb1) conn.add_notice_callback(cb2) conn.add_notice_callback("the wrong thing") - conn.add_notice_callback(cb3) + conn.add_notice_callback( + lambda diag: severities.append(diag.severity_nonlocalized) + ) cur = conn.cursor() cur.execute( - """ -do $$ -begin - raise notice 'hello notice'; -end -$$ language plpgsql - """ + "do $$begin raise notice 'hello notice'; end$$ language plpgsql" ) - assert messages == [b"hello notice"] - assert severities == [b"NOTICE"] + assert messages == ["hello notice"] + assert severities == ["NOTICE"] assert len(caplog.records) == 2 rec = caplog.records[0] @@ -255,17 +244,11 @@ $$ language plpgsql conn.remove_notice_callback(cb1) conn.remove_notice_callback("the wrong thing") cur.execute( - """ -do $$ -begin - raise warning 'hello warning'; -end -$$ language plpgsql - """ + "do $$begin raise warning 'hello warning'; end$$ language plpgsql" ) assert len(caplog.records) == 3 - assert messages == [b"hello notice"] - assert severities == [b"NOTICE", b"WARNING"] + assert messages == ["hello notice"] + assert severities == ["NOTICE", "WARNING"] with pytest.raises(ValueError): conn.remove_notice_callback(cb1) diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 000000000..dc18e7985 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,73 @@ +import pytest +from psycopg3 import errors as e + +eur = "\u20ac" + + +def test_error_diag(conn): + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select 1 from wat") + + exc = excinfo.value + diag = exc.diag + assert diag.sqlstate == "42P01" + assert diag.severity_nonlocalized == "ERROR" + + +def test_diag_all_attrs(pgconn, pq): + res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR) + diag = e.Diagnostic(res) + for d in pq.DiagnosticField: + val = getattr(diag, d.name.lower()) + assert val is None or isinstance(val, str) + + +def test_diag_right_attr(pgconn, pq, monkeypatch): + res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR) + diag = e.Diagnostic(res) + + checked = [] + + def check_val(self, v): + nonlocal to_check + assert to_check == v + checked.append(v) + return None + + monkeypatch.setattr(e.Diagnostic, "_error_message", check_val) + + for to_check in pq.DiagnosticField: + getattr(diag, to_check.name.lower()) + + assert len(checked) == len(pq.DiagnosticField) + + +@pytest.mark.parametrize("enc", ["utf8", "latin9"]) +def test_diag_encoding(conn, enc): + msgs = [] + conn.add_notice_callback(lambda diag: msgs.append(diag.message_primary)) + conn.set_client_encoding(enc) + cur = conn.cursor() + cur.execute( + "do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql" + ) + assert msgs == [f"hello {eur}"] + + +@pytest.mark.parametrize("enc", ["utf8", "latin9"]) +def test_error_encoding(conn, enc): + conn.set_client_encoding(enc) + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute( + """ + do $$begin + execute format('insert into "%s" values (1)', chr(8364)); + end$$ language plpgsql; + """ + ) + + diag = excinfo.value.diag + assert f'"{eur}"' in diag.message_primary + assert diag.sqlstate == "42P01" -- 2.47.2