]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
More care with functions returning null in PGconn methods
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 07:48:48 +0000 (19:48 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 07:55:48 +0000 (19:55 +1200)
ctypes stub changed: by default the function returning bytes can return
NULL. The ones which never do it are few, so these are the ones which
get special-cased.

psycopg3/pq/_pq_ctypes.py
psycopg3/pq/_pq_ctypes.pyi
psycopg3/pq/pq_ctypes.py
tests/pq/test_pgconn.py

index 3dc48645288fbc4cf4fe555d8130bb29133cfca4..0d232466db1cce011cba14a65a47770c6b960c26 100644 (file)
@@ -458,7 +458,10 @@ def generate_stub() -> None:
         elif t is c_int or t is c_uint or t is c_size_t:
             return "int"
         elif t is c_char_p or t.__name__ == "LP_c_char":
-            return "bytes"
+            if narg is not None:
+                return "bytes"
+            else:
+                return "Optional[bytes]"
 
         elif t.__name__ in ("LP_PGconn_struct", "LP_PGresult_struct",):
             if narg is not None:
index ece5b72384ae3f059a5c9a553204b02bf81bc1b9..a47e4530e2e012a77ebf42e4a92210d5ad6fd0bc 100644 (file)
@@ -23,6 +23,8 @@ class PQconninfoOption_struct:
     dispsize: int
 
 def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ...
 def PQexecPrepared(
     arg1: Optional[PGconn_struct],
     arg2: bytes,
@@ -39,14 +41,10 @@ def PQprepare(
     arg4: int,
     arg5: Optional[Array[c_uint]],
 ) -> PGresult_struct: ...
-def PQresultErrorField(
-    arg1: Optional[PGresult_struct], arg2: int
-) -> Optional[bytes]: ...
-def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
 def PQgetvalue(
     arg1: Optional[PGresult_struct], arg2: int, arg3: int
 ) -> pointer[c_char]: ...
-def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
+def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
 
 # fmt: off
 # autogenerated: start
@@ -63,20 +61,19 @@ def PQreset(arg1: Optional[PGconn_struct]) -> None: ...
 def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ...
 def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ...
 def PQping(arg1: bytes) -> int: ...
-def PQdb(arg1: Optional[PGconn_struct]) -> bytes: ...
-def PQuser(arg1: Optional[PGconn_struct]) -> bytes: ...
-def PQpass(arg1: Optional[PGconn_struct]) -> bytes: ...
-def PQhost(arg1: Optional[PGconn_struct]) -> bytes: ...
-def _PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ...
-def PQport(arg1: Optional[PGconn_struct]) -> bytes: ...
-def PQtty(arg1: Optional[PGconn_struct]) -> bytes: ...
-def PQoptions(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQdb(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQuser(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQpass(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQhost(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def _PQhostaddr(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQport(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQtty(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQoptions(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
 def PQstatus(arg1: Optional[PGconn_struct]) -> int: ...
 def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ...
-def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> bytes: ...
+def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> Optional[bytes]: ...
 def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ...
 def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ...
-def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ...
 def PQsocket(arg1: Optional[PGconn_struct]) -> int: ...
 def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ...
 def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ...
@@ -87,10 +84,11 @@ def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: po
 def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
 def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
 def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ...
-def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
 def PQclear(arg1: Optional[PGresult_struct]) -> None: ...
 def PQntuples(arg1: Optional[PGresult_struct]) -> int: ...
 def PQnfields(arg1: Optional[PGresult_struct]) -> int: ...
+def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
 def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
@@ -102,7 +100,7 @@ def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: .
 def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
 def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
 def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
-def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
 def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
 def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ...
 def PQescapeBytea(arg1: bytes, arg2: int, arg3: pointer[c_ulong]) -> pointer[c_ubyte]: ...
index 2fef75be0e62c7c9e9fc26311c29d619609d44ee..ce9f5e765d8a2ab2491a9a120c5adee7bfe79259 100644 (file)
@@ -10,7 +10,7 @@ implementation.
 
 from ctypes import Array, pointer, string_at
 from ctypes import c_char_p, c_int, c_size_t, c_ulong
-from typing import Any, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence
 from typing import cast as t_cast
 
 from .enums import (
@@ -75,6 +75,8 @@ class PGconn:
 
     @property
     def info(self) -> List["ConninfoOption"]:
+        if not self.pgconn_ptr:
+            raise PQerror("the connection is closed")
         opts = impl.PQconninfo(self.pgconn_ptr)
         if not opts:
             raise MemoryError("couldn't allocate connection info")
@@ -104,35 +106,35 @@ class PGconn:
 
     @property
     def db(self) -> bytes:
-        return impl.PQdb(self.pgconn_ptr)
+        return self._call_bytes(impl.PQdb)
 
     @property
     def user(self) -> bytes:
-        return impl.PQuser(self.pgconn_ptr)
+        return self._call_bytes(impl.PQuser)
 
     @property
     def password(self) -> bytes:
-        return impl.PQpass(self.pgconn_ptr)
+        return self._call_bytes(impl.PQpass)
 
     @property
     def host(self) -> bytes:
-        return impl.PQhost(self.pgconn_ptr)
+        return self._call_bytes(impl.PQhost)
 
     @property
     def hostaddr(self) -> bytes:
-        return impl.PQhostaddr(self.pgconn_ptr)
+        return self._call_bytes(impl.PQhostaddr)
 
     @property
     def port(self) -> bytes:
-        return impl.PQport(self.pgconn_ptr)
+        return self._call_bytes(impl.PQport)
 
     @property
     def tty(self) -> bytes:
-        return impl.PQtty(self.pgconn_ptr)
+        return self._call_bytes(impl.PQtty)
 
     @property
     def options(self) -> bytes:
-        return impl.PQoptions(self.pgconn_ptr)
+        return self._call_bytes(impl.PQoptions)
 
     @property
     def status(self) -> ConnStatus:
@@ -145,39 +147,41 @@ class PGconn:
         return TransactionStatus(rv)
 
     def parameter_status(self, name: bytes) -> Optional[bytes]:
+        if not self.pgconn_ptr:
+            raise PQerror("the connection is closed")
         return impl.PQparameterStatus(self.pgconn_ptr, name)
 
     @property
-    def protocol_version(self) -> int:
-        return impl.PQprotocolVersion(self.pgconn_ptr)
+    def error_message(self) -> bytes:
+        return impl.PQerrorMessage(self.pgconn_ptr)
 
     @property
-    def server_version(self) -> int:
-        return impl.PQserverVersion(self.pgconn_ptr)
+    def protocol_version(self) -> int:
+        return self._call_int(impl.PQprotocolVersion)
 
     @property
-    def error_message(self) -> bytes:
-        return impl.PQerrorMessage(self.pgconn_ptr)
+    def server_version(self) -> int:
+        return self._call_int(impl.PQserverVersion)
 
     @property
     def socket(self) -> int:
-        return impl.PQsocket(self.pgconn_ptr)
+        return self._call_int(impl.PQsocket)
 
     @property
     def backend_pid(self) -> int:
-        return impl.PQbackendPID(self.pgconn_ptr)
+        return self._call_int(impl.PQbackendPID)
 
     @property
     def needs_password(self) -> bool:
-        return bool(impl.PQconnectionNeedsPassword(self.pgconn_ptr))
+        return self._call_bool(impl.PQconnectionNeedsPassword)
 
     @property
     def used_password(self) -> bool:
-        return bool(impl.PQconnectionUsedPassword(self.pgconn_ptr))
+        return self._call_bool(impl.PQconnectionUsedPassword)
 
     @property
     def ssl_in_use(self) -> bool:
-        return bool(impl.PQsslInUse(self.pgconn_ptr))
+        return self._call_bool(impl.PQsslInUse)
 
     def exec_(self, command: bytes) -> "PGresult":
         if not isinstance(command, bytes):
@@ -393,6 +397,34 @@ class PGconn:
             raise MemoryError("couldn't allocate empty PGresult")
         return PGresult(rv)
 
+    def _call_bytes(
+        self, func: Callable[[impl.PGconn_struct], Optional[bytes]]
+    ) -> bytes:
+        """
+        Call one of the pgconn libpq functions returning a bytes pointer.
+        """
+        if not self.pgconn_ptr:
+            raise PQerror("the connection is closed")
+        rv = func(self.pgconn_ptr)
+        assert rv is not None
+        return rv
+
+    def _call_int(self, func: Callable[[impl.PGconn_struct], int]) -> int:
+        """
+        Call one of the pgconn libpq functions returning an int.
+        """
+        if not self.pgconn_ptr:
+            raise PQerror("the connection is closed")
+        return func(self.pgconn_ptr)
+
+    def _call_bool(self, func: Callable[[impl.PGconn_struct], int]) -> bool:
+        """
+        Call one of the pgconn libpq functions returning a logical value.
+        """
+        if not self.pgconn_ptr:
+            raise PQerror("the connection is closed")
+        return bool(func(self.pgconn_ptr))
+
 
 class PGresult:
     __slots__ = ("pgresult_ptr",)
index 131347935e4c8e1338c096c6d69bdceb9d251105..881a29cc78b3eec5c79bc865a990ceab5399963d 100644 (file)
@@ -56,6 +56,14 @@ def test_connect_async_bad(pq, dsn):
     assert conn.status == pq.ConnStatus.BAD
 
 
+def test_finish(pgconn, pq):
+    assert pgconn.status == pq.ConnStatus.OK
+    pgconn.finish()
+    assert pgconn.status == pq.ConnStatus.BAD
+    pgconn.finish()
+    assert pgconn.status == pq.ConnStatus.BAD
+
+
 def test_info(pq, dsn, pgconn):
     info = pgconn.info
     assert len(info) > 20
@@ -69,20 +77,31 @@ def test_info(pq, dsn, pgconn):
     name = [o.val for o in parsed if o.keyword == b"dbname"][0]
     assert dbname.val == name
 
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.info
+
 
 def test_reset(pq, pgconn):
     assert pgconn.status == pq.ConnStatus.OK
-    # TODO: break it
+    pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())")
+    assert pgconn.status == pq.ConnStatus.BAD
     pgconn.reset()
     assert pgconn.status == pq.ConnStatus.OK
 
+    # doesn't work after finish, but doesn't die either
+    pgconn.finish()
+    pgconn.reset()
+    assert pgconn.status == pq.ConnStatus.BAD
+
 
 def test_reset_async(pq, pgconn):
     assert pgconn.status == pq.ConnStatus.OK
-    # TODO: break it
+    pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())")
+    assert pgconn.status == pq.ConnStatus.BAD
     pgconn.reset_start()
     while 1:
-        rv = pgconn.connect_poll()
+        rv = pgconn.reset_poll()
         if rv == pq.PollingStatus.READING:
             select([pgconn.socket], [], [])
         elif rv == pq.PollingStatus.WRITING:
@@ -93,6 +112,12 @@ def test_reset_async(pq, pgconn):
     assert rv == pq.PollingStatus.OK
     assert pgconn.status == pq.ConnStatus.OK
 
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.reset_start()
+
+    assert pgconn.reset_poll() == 0
+
 
 def test_ping(pq, dsn):
     rv = pq.PGconn.ping(dsn.encode("utf8"))
@@ -105,27 +130,42 @@ def test_ping(pq, dsn):
 def test_db(pgconn):
     name = [o.val for o in pgconn.info if o.keyword == b"dbname"][0]
     assert pgconn.db == name
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.db
 
 
 def test_user(pgconn):
     user = [o.val for o in pgconn.info if o.keyword == b"user"][0]
     assert pgconn.user == user
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.user
 
 
 def test_password(pgconn):
     # not in info
     assert isinstance(pgconn.password, bytes)
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.password
 
 
 def test_host(pgconn):
     # might be not in info
     assert isinstance(pgconn.host, bytes)
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.host
 
 
 @pytest.mark.libpq(">= 12")
 def test_hostaddr(pgconn):
     # not in info
     assert isinstance(pgconn.hostaddr, bytes), pgconn.hostaddr
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.hostaddr
 
 
 @pytest.mark.libpq("< 12")
@@ -134,14 +174,30 @@ def test_hostaddr_missing(pgconn):
         pgconn.hostaddr
 
 
+def test_port(pgconn):
+    port = [o.val for o in pgconn.info if o.keyword == b"port"][0]
+    assert pgconn.port == port
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.port
+
+
 def test_tty(pgconn):
     tty = [o.val for o in pgconn.info if o.keyword == b"tty"][0]
     assert pgconn.tty == tty
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.tty
 
 
 def test_transaction_status(pq, pgconn):
     assert pgconn.transaction_status == pq.TransactionStatus.IDLE
-    # TODO: test other states
+    pgconn.exec_(b"begin")
+    assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
+    pgconn.send_query(b"select 1")
+    assert pgconn.transaction_status == pq.TransactionStatus.ACTIVE
+    psycopg3.waiting.wait(psycopg3.Connection._exec_gen(pgconn))
+    assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
     pgconn.finish()
     assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN
 
@@ -151,6 +207,9 @@ def test_parameter_status(pq, dsn, tempenv):
     pgconn = pq.PGconn.connect(dsn.encode("utf8"))
     assert pgconn.parameter_status(b"application_name") == b"psycopg3 tests"
     assert pgconn.parameter_status(b"wat") is None
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.parameter_status(b"application_name")
 
 
 def test_encoding(pq, pgconn):
@@ -166,29 +225,48 @@ def test_encoding(pq, pgconn):
     assert res.status == pq.ExecStatus.FATAL_ERROR
     assert pgconn.parameter_status(b"client_encoding") == b"UTF8"
 
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.parameter_status(b"client_encoding")
+
 
 def test_protocol_version(pgconn):
     assert pgconn.protocol_version == 3
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.protocol_version
 
 
 def test_server_version(pgconn):
     assert pgconn.server_version >= 90400
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.server_version
 
 
 def test_error_message(pq, pgconn):
+    assert pgconn.error_message == b""
     res = pgconn.exec_(b"wat")
     assert res.status == pq.ExecStatus.FATAL_ERROR
     msg = pgconn.error_message
     assert b"wat" in msg
+    pgconn.finish()
+    assert b"NULL" in pgconn.error_message  # TODO: i10n?
 
 
 def test_backend_pid(pgconn):
     assert 2 <= pgconn.backend_pid <= 65535  # Unless increased in kernel?
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.backend_pid
 
 
 def test_needs_password(pgconn):
     # assume connection worked so an eventually needed password wasn't missing
     assert pgconn.needs_password is False
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.needs_password
 
 
 def test_used_password(pq, pgconn, tempenv, dsn):
@@ -206,6 +284,10 @@ def test_used_password(pq, pgconn, tempenv, dsn):
     if has_password:
         assert pgconn.used_password
 
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.used_password
+
 
 def test_ssl_in_use(pgconn):
     assert isinstance(pgconn.ssl_in_use, bool)
@@ -220,9 +302,18 @@ def test_ssl_in_use(pgconn):
             # but maybe unlikely in the tests environment?
             assert pgconn.ssl_in_use
 
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        pgconn.ssl_in_use
+
 
 def test_make_empty_result(pq, pgconn):
     pgconn.exec_(b"wat")
     res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
     assert res.status == pq.ExecStatus.FATAL_ERROR
     assert b"wat" in res.error_message
+
+    pgconn.finish()
+    res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
+    assert res.status == pq.ExecStatus.FATAL_ERROR
+    assert res.error_message == b""