]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
More care in None return values from libpq result functions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 06:03:26 +0000 (18:03 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 06:03:26 +0000 (18:03 +1200)
psycopg3/errors.py
psycopg3/pq/_pq_ctypes.pyi
psycopg3/pq/misc.py
psycopg3/pq/pq_ctypes.py
tests/pq/test_pgresult.py

index dd0207a54f40f508217ff85043de0f6ac692ea2b..a18f3c8bfb9ba86e5ad4b24b246218d559a50e93 100644 (file)
@@ -115,5 +115,6 @@ def class_for_state(sqlstate: bytes) -> Type[Error]:
 def error_from_result(result: "PGresult") -> Error:
     from psycopg3 import pq
 
-    cls = class_for_state(result.error_field(pq.DiagnosticField.SQLSTATE))
+    state = result.error_field(pq.DiagnosticField.SQLSTATE) or b""
+    cls = class_for_state(state)
     return cls(pq.error_message(result))
index 340e632058310772cfc5242b51088daa7ab131e3..ece5b72384ae3f059a5c9a553204b02bf81bc1b9 100644 (file)
@@ -39,9 +39,14 @@ def PQprepare(
     arg4: int,
     arg5: Optional[Array[c_uint]],
 ) -> PGresult_struct: ...
+def PQresultErrorField(
+    arg1: Optional[PGresult_struct], arg2: int
+) -> Optional[bytes]: ...
+def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
 def PQgetvalue(
     arg1: Optional[PGresult_struct], arg2: int, arg3: int
 ) -> pointer[c_char]: ...
+def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
 
 # fmt: off
 # autogenerated: start
@@ -83,11 +88,9 @@ def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_s
 def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
 def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ...
 def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ...
-def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> bytes: ...
 def PQclear(arg1: Optional[PGresult_struct]) -> None: ...
 def PQntuples(arg1: Optional[PGresult_struct]) -> int: ...
 def PQnfields(arg1: Optional[PGresult_struct]) -> int: ...
-def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> bytes: ...
 def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
@@ -99,7 +102,6 @@ def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: .
 def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
 def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
 def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
-def PQcmdStatus(arg1: Optional[PGresult_struct]) -> bytes: ...
 def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
 def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
 def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ...
index c6b020500c5fa7a9a9a2de5b49b55163116fe442..e1068ffacdedd9acbabd5a3af6b28b9c9b868d47 100644 (file)
@@ -34,7 +34,7 @@ def error_message(obj: Union["PGconn", "PGresult"]) -> str:
             bmsg = bmsg.splitlines()[0].split(b":", 1)[-1].strip()
 
     elif isinstance(obj, pq.PGresult):
-        bmsg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
+        bmsg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) or b""
         if not bmsg:
             bmsg = obj.error_message
 
index d0f9b3bb3640a5383ee15a307ddc3ad0d9fbd043..2fef75be0e62c7c9e9fc26311c29d619609d44ee 100644 (file)
@@ -417,7 +417,7 @@ class PGresult:
     def error_message(self) -> bytes:
         return impl.PQresultErrorMessage(self.pgresult_ptr)
 
-    def error_field(self, fieldcode: DiagnosticField) -> bytes:
+    def error_field(self, fieldcode: DiagnosticField) -> Optional[bytes]:
         return impl.PQresultErrorField(self.pgresult_ptr, fieldcode)
 
     @property
@@ -428,7 +428,7 @@ class PGresult:
     def nfields(self) -> int:
         return impl.PQnfields(self.pgresult_ptr)
 
-    def fname(self, column_number: int) -> bytes:
+    def fname(self, column_number: int) -> Optional[bytes]:
         return impl.PQfname(self.pgresult_ptr, column_number)
 
     def ftable(self, column_number: int) -> int:
@@ -476,7 +476,7 @@ class PGresult:
         return impl.PQparamtype(self.pgresult_ptr, param_number)
 
     @property
-    def command_status(self) -> bytes:
+    def command_status(self) -> Optional[bytes]:
         return impl.PQcmdStatus(self.pgresult_ptr)
 
     @property
index 3c3641ac04d3b13ee6203b55321d5aa7d846a53d..559e8fbfbcfabc9a22f5a0bc4c44c13021f0203a 100644 (file)
@@ -13,6 +13,8 @@ import pytest
 def test_status(pq, pgconn, command, status):
     res = pgconn.exec_(command)
     assert res.status == getattr(pq.ExecStatus, status)
+    res.clear()
+    assert res.status == pq.ExecStatus.FATAL_ERROR
 
 
 def test_error_message(pgconn):
@@ -20,6 +22,8 @@ def test_error_message(pgconn):
     assert res.error_message == b""
     res = pgconn.exec_(b"select wat")
     assert b"wat" in res.error_message
+    res.clear()
+    assert res.error_message == b""
 
 
 def test_error_field(pq, pgconn):
@@ -27,6 +31,8 @@ def test_error_field(pq, pgconn):
     assert res.error_field(pq.DiagnosticField.SEVERITY) == b"ERROR"
     assert res.error_field(pq.DiagnosticField.SQLSTATE) == b"42703"
     assert b"wat" in res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
+    res.clear()
+    assert res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) is None
 
 
 @pytest.mark.parametrize("n", range(4))
@@ -35,12 +41,16 @@ def test_ntuples(pgconn, n):
         b"select generate_series(1, $1)", [str(n).encode("ascii")]
     )
     assert res.ntuples == n
+    res.clear()
+    assert res.ntuples == 0
 
 
 def test_nfields(pgconn):
+    res = pgconn.exec_(b"select wat")
+    assert res.nfields == 0
     res = pgconn.exec_(b"select 1, 2, 3")
     assert res.nfields == 3
-    res = pgconn.exec_(b"select wat")
+    res.clear()
     assert res.nfields == 0
 
 
@@ -48,6 +58,10 @@ def test_fname(pgconn):
     res = pgconn.exec_(b'select 1 as foo, 2 as "BAR"')
     assert res.fname(0) == b"foo"
     assert res.fname(1) == b"BAR"
+    assert res.fname(2) is None
+    assert res.fname(-1) is None
+    res.clear()
+    assert res.fname(0) is None
 
 
 def test_ftable_and_col(pq, pgconn):
@@ -69,6 +83,9 @@ def test_ftable_and_col(pq, pgconn):
     assert res.ftable(1) == int(res.get_value(0, 3).decode("ascii"))
     assert res.ftablecol(0) == 1
     assert res.ftablecol(1) == 2
+    res.clear()
+    assert res.ftable(0) == 0
+    assert res.ftablecol(0) == 0
 
 
 @pytest.mark.parametrize("fmt", (0, 1))
@@ -77,6 +94,9 @@ def test_fformat(pq, pgconn, fmt):
     assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
     assert res.fformat(0) == fmt
     assert res.binary_tuples == fmt
+    res.clear()
+    assert res.fformat(0) == 0
+    assert res.binary_tuples == 0
 
 
 def test_ftype(pq, pgconn):
@@ -85,6 +105,8 @@ def test_ftype(pq, pgconn):
     assert res.ftype(0) == 23
     assert res.ftype(1) == 1700
     assert res.ftype(2) == 25
+    res.clear()
+    assert res.ftype(0) == 0
 
 
 def test_fmod(pq, pgconn):
@@ -93,6 +115,8 @@ def test_fmod(pq, pgconn):
     assert res.fmod(0) == -1
     assert res.fmod(1) == 0xA0004
     assert res.fmod(2) == 0xA0006
+    res.clear()
+    assert res.fmod(0) == 0
 
 
 def test_fsize(pq, pgconn):
@@ -101,6 +125,8 @@ def test_fsize(pq, pgconn):
     assert res.fsize(0) == 4
     assert res.fsize(1) == 8
     assert res.fsize(2) == -1
+    res.clear()
+    assert res.fsize(0) == 0
 
 
 def test_get_value(pq, pgconn):
@@ -109,6 +135,8 @@ def test_get_value(pq, pgconn):
     assert res.get_value(0, 0) == b"a"
     assert res.get_value(0, 1) == b""
     assert res.get_value(0, 2) is None
+    res.clear()
+    assert res.get_value(0, 0) is None
 
 
 def test_nparams_types(pq, pgconn):
@@ -122,21 +150,31 @@ def test_nparams_types(pq, pgconn):
     assert res.param_type(0) == 23
     assert res.param_type(1) == 25
 
+    res.clear()
+    assert res.nparams == 0
+    assert res.param_type(0) == 0
+
 
 def test_command_status(pq, pgconn):
     res = pgconn.exec_(b"select 1")
     assert res.command_status == b"SELECT 1"
     res = pgconn.exec_(b"set timezone to utf8")
     assert res.command_status == b"SET"
+    res.clear()
+    assert res.command_status is None
 
 
 def test_command_tuples(pq, pgconn):
+    res = pgconn.exec_(b"set timezone to utf8")
+    assert res.command_tuples is None
     res = pgconn.exec_(b"select * from generate_series(1, 10)")
     assert res.command_tuples == 10
-    res = pgconn.exec_(b"set timezone to utf8")
+    res.clear()
     assert res.command_tuples is None
 
 
 def test_oid_value(pq, pgconn):
     res = pgconn.exec_(b"select 1")
     assert res.oid_value == 0
+    res.clear()
+    assert res.oid_value == 0