From 7cff3df14a3ce09b92cdf09abb2c48984ab8b112 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sat, 11 Apr 2020 18:03:26 +1200 Subject: [PATCH] More care in None return values from libpq result functions --- psycopg3/errors.py | 3 ++- psycopg3/pq/_pq_ctypes.pyi | 8 +++++--- psycopg3/pq/misc.py | 2 +- psycopg3/pq/pq_ctypes.py | 6 +++--- tests/pq/test_pgresult.py | 42 ++++++++++++++++++++++++++++++++++++-- 5 files changed, 51 insertions(+), 10 deletions(-) diff --git a/psycopg3/errors.py b/psycopg3/errors.py index dd0207a54..a18f3c8bf 100644 --- a/psycopg3/errors.py +++ b/psycopg3/errors.py @@ -115,5 +115,6 @@ def class_for_state(sqlstate: bytes) -> Type[Error]: def error_from_result(result: "PGresult") -> Error: from psycopg3 import pq - cls = class_for_state(result.error_field(pq.DiagnosticField.SQLSTATE)) + state = result.error_field(pq.DiagnosticField.SQLSTATE) or b"" + cls = class_for_state(state) return cls(pq.error_message(result)) diff --git a/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/pq/_pq_ctypes.pyi index 340e63205..ece5b7238 100644 --- a/psycopg3/pq/_pq_ctypes.pyi +++ b/psycopg3/pq/_pq_ctypes.pyi @@ -39,9 +39,14 @@ def PQprepare( arg4: int, arg5: Optional[Array[c_uint]], ) -> PGresult_struct: ... +def PQresultErrorField( + arg1: Optional[PGresult_struct], arg2: int +) -> Optional[bytes]: ... +def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... def PQgetvalue( arg1: Optional[PGresult_struct], arg2: int, arg3: int ) -> pointer[c_char]: ... +def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ... # fmt: off # autogenerated: start @@ -83,11 +88,9 @@ def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_s def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ... def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ... -def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> bytes: ... def PQclear(arg1: Optional[PGresult_struct]) -> None: ... def PQntuples(arg1: Optional[PGresult_struct]) -> int: ... def PQnfields(arg1: Optional[PGresult_struct]) -> int: ... -def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> bytes: ... def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ... def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ... def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ... @@ -99,7 +102,6 @@ def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: . def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... def PQnparams(arg1: Optional[PGresult_struct]) -> int: ... def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... -def PQcmdStatus(arg1: Optional[PGresult_struct]) -> bytes: ... def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ... def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ... def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ... diff --git a/psycopg3/pq/misc.py b/psycopg3/pq/misc.py index c6b020500..e1068ffac 100644 --- a/psycopg3/pq/misc.py +++ b/psycopg3/pq/misc.py @@ -34,7 +34,7 @@ def error_message(obj: Union["PGconn", "PGresult"]) -> str: bmsg = bmsg.splitlines()[0].split(b":", 1)[-1].strip() elif isinstance(obj, pq.PGresult): - bmsg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) + bmsg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) or b"" if not bmsg: bmsg = obj.error_message diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index d0f9b3bb3..2fef75be0 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -417,7 +417,7 @@ class PGresult: def error_message(self) -> bytes: return impl.PQresultErrorMessage(self.pgresult_ptr) - def error_field(self, fieldcode: DiagnosticField) -> bytes: + def error_field(self, fieldcode: DiagnosticField) -> Optional[bytes]: return impl.PQresultErrorField(self.pgresult_ptr, fieldcode) @property @@ -428,7 +428,7 @@ class PGresult: def nfields(self) -> int: return impl.PQnfields(self.pgresult_ptr) - def fname(self, column_number: int) -> bytes: + def fname(self, column_number: int) -> Optional[bytes]: return impl.PQfname(self.pgresult_ptr, column_number) def ftable(self, column_number: int) -> int: @@ -476,7 +476,7 @@ class PGresult: return impl.PQparamtype(self.pgresult_ptr, param_number) @property - def command_status(self) -> bytes: + def command_status(self) -> Optional[bytes]: return impl.PQcmdStatus(self.pgresult_ptr) @property diff --git a/tests/pq/test_pgresult.py b/tests/pq/test_pgresult.py index 3c3641ac0..559e8fbfb 100644 --- a/tests/pq/test_pgresult.py +++ b/tests/pq/test_pgresult.py @@ -13,6 +13,8 @@ import pytest def test_status(pq, pgconn, command, status): res = pgconn.exec_(command) assert res.status == getattr(pq.ExecStatus, status) + res.clear() + assert res.status == pq.ExecStatus.FATAL_ERROR def test_error_message(pgconn): @@ -20,6 +22,8 @@ def test_error_message(pgconn): assert res.error_message == b"" res = pgconn.exec_(b"select wat") assert b"wat" in res.error_message + res.clear() + assert res.error_message == b"" def test_error_field(pq, pgconn): @@ -27,6 +31,8 @@ def test_error_field(pq, pgconn): assert res.error_field(pq.DiagnosticField.SEVERITY) == b"ERROR" assert res.error_field(pq.DiagnosticField.SQLSTATE) == b"42703" assert b"wat" in res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) + res.clear() + assert res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) is None @pytest.mark.parametrize("n", range(4)) @@ -35,12 +41,16 @@ def test_ntuples(pgconn, n): b"select generate_series(1, $1)", [str(n).encode("ascii")] ) assert res.ntuples == n + res.clear() + assert res.ntuples == 0 def test_nfields(pgconn): + res = pgconn.exec_(b"select wat") + assert res.nfields == 0 res = pgconn.exec_(b"select 1, 2, 3") assert res.nfields == 3 - res = pgconn.exec_(b"select wat") + res.clear() assert res.nfields == 0 @@ -48,6 +58,10 @@ def test_fname(pgconn): res = pgconn.exec_(b'select 1 as foo, 2 as "BAR"') assert res.fname(0) == b"foo" assert res.fname(1) == b"BAR" + assert res.fname(2) is None + assert res.fname(-1) is None + res.clear() + assert res.fname(0) is None def test_ftable_and_col(pq, pgconn): @@ -69,6 +83,9 @@ def test_ftable_and_col(pq, pgconn): assert res.ftable(1) == int(res.get_value(0, 3).decode("ascii")) assert res.ftablecol(0) == 1 assert res.ftablecol(1) == 2 + res.clear() + assert res.ftable(0) == 0 + assert res.ftablecol(0) == 0 @pytest.mark.parametrize("fmt", (0, 1)) @@ -77,6 +94,9 @@ def test_fformat(pq, pgconn, fmt): assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message assert res.fformat(0) == fmt assert res.binary_tuples == fmt + res.clear() + assert res.fformat(0) == 0 + assert res.binary_tuples == 0 def test_ftype(pq, pgconn): @@ -85,6 +105,8 @@ def test_ftype(pq, pgconn): assert res.ftype(0) == 23 assert res.ftype(1) == 1700 assert res.ftype(2) == 25 + res.clear() + assert res.ftype(0) == 0 def test_fmod(pq, pgconn): @@ -93,6 +115,8 @@ def test_fmod(pq, pgconn): assert res.fmod(0) == -1 assert res.fmod(1) == 0xA0004 assert res.fmod(2) == 0xA0006 + res.clear() + assert res.fmod(0) == 0 def test_fsize(pq, pgconn): @@ -101,6 +125,8 @@ def test_fsize(pq, pgconn): assert res.fsize(0) == 4 assert res.fsize(1) == 8 assert res.fsize(2) == -1 + res.clear() + assert res.fsize(0) == 0 def test_get_value(pq, pgconn): @@ -109,6 +135,8 @@ def test_get_value(pq, pgconn): assert res.get_value(0, 0) == b"a" assert res.get_value(0, 1) == b"" assert res.get_value(0, 2) is None + res.clear() + assert res.get_value(0, 0) is None def test_nparams_types(pq, pgconn): @@ -122,21 +150,31 @@ def test_nparams_types(pq, pgconn): assert res.param_type(0) == 23 assert res.param_type(1) == 25 + res.clear() + assert res.nparams == 0 + assert res.param_type(0) == 0 + def test_command_status(pq, pgconn): res = pgconn.exec_(b"select 1") assert res.command_status == b"SELECT 1" res = pgconn.exec_(b"set timezone to utf8") assert res.command_status == b"SET" + res.clear() + assert res.command_status is None def test_command_tuples(pq, pgconn): + res = pgconn.exec_(b"set timezone to utf8") + assert res.command_tuples is None res = pgconn.exec_(b"select * from generate_series(1, 10)") assert res.command_tuples == 10 - res = pgconn.exec_(b"set timezone to utf8") + res.clear() assert res.command_tuples is None def test_oid_value(pq, pgconn): res = pgconn.exec_(b"select 1") assert res.oid_value == 0 + res.clear() + assert res.oid_value == 0 -- 2.47.2