From: Daniele Varrazzo Date: Sat, 11 Apr 2020 07:48:48 +0000 (+1200) Subject: More care with functions returning null in PGconn methods X-Git-Tag: 3.0.dev0~580 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=215d5566bec5289cdf8bad88787e153e76213700;p=thirdparty%2Fpsycopg.git More care with functions returning null in PGconn methods ctypes stub changed: by default the function returning bytes can return NULL. The ones which never do it are few, so these are the ones which get special-cased. --- diff --git a/psycopg3/pq/_pq_ctypes.py b/psycopg3/pq/_pq_ctypes.py index 3dc486452..0d232466d 100644 --- a/psycopg3/pq/_pq_ctypes.py +++ b/psycopg3/pq/_pq_ctypes.py @@ -458,7 +458,10 @@ def generate_stub() -> None: elif t is c_int or t is c_uint or t is c_size_t: return "int" elif t is c_char_p or t.__name__ == "LP_c_char": - return "bytes" + if narg is not None: + return "bytes" + else: + return "Optional[bytes]" elif t.__name__ in ("LP_PGconn_struct", "LP_PGresult_struct",): if narg is not None: diff --git a/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/pq/_pq_ctypes.pyi index ece5b7238..a47e4530e 100644 --- a/psycopg3/pq/_pq_ctypes.pyi +++ b/psycopg3/pq/_pq_ctypes.pyi @@ -23,6 +23,8 @@ class PQconninfoOption_struct: dispsize: int def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ... +def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ... +def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ... def PQexecPrepared( arg1: Optional[PGconn_struct], arg2: bytes, @@ -39,14 +41,10 @@ def PQprepare( arg4: int, arg5: Optional[Array[c_uint]], ) -> PGresult_struct: ... -def PQresultErrorField( - arg1: Optional[PGresult_struct], arg2: int -) -> Optional[bytes]: ... -def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... def PQgetvalue( arg1: Optional[PGresult_struct], arg2: int, arg3: int ) -> pointer[c_char]: ... -def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ... +def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ... # fmt: off # autogenerated: start @@ -63,20 +61,19 @@ def PQreset(arg1: Optional[PGconn_struct]) -> None: ... def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ... def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ... def PQping(arg1: bytes) -> int: ... -def PQdb(arg1: Optional[PGconn_struct]) -> bytes: ... -def PQuser(arg1: Optional[PGconn_struct]) -> bytes: ... -def PQpass(arg1: Optional[PGconn_struct]) -> bytes: ... -def PQhost(arg1: Optional[PGconn_struct]) -> bytes: ... -def _PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ... -def PQport(arg1: Optional[PGconn_struct]) -> bytes: ... -def PQtty(arg1: Optional[PGconn_struct]) -> bytes: ... -def PQoptions(arg1: Optional[PGconn_struct]) -> bytes: ... +def PQdb(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQuser(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQpass(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQhost(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def _PQhostaddr(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQport(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQtty(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQoptions(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... def PQstatus(arg1: Optional[PGconn_struct]) -> int: ... def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ... -def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> bytes: ... +def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> Optional[bytes]: ... def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ... def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ... -def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ... def PQsocket(arg1: Optional[PGconn_struct]) -> int: ... def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ... def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ... @@ -87,10 +84,11 @@ def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: po def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ... -def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ... +def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... def PQclear(arg1: Optional[PGresult_struct]) -> None: ... def PQntuples(arg1: Optional[PGresult_struct]) -> int: ... def PQnfields(arg1: Optional[PGresult_struct]) -> int: ... +def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ... def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ... def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ... @@ -102,7 +100,7 @@ def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: . def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... def PQnparams(arg1: Optional[PGresult_struct]) -> int: ... def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... -def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ... +def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ... def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ... def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ... def PQescapeBytea(arg1: bytes, arg2: int, arg3: pointer[c_ulong]) -> pointer[c_ubyte]: ... diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index 2fef75be0..ce9f5e765 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -10,7 +10,7 @@ implementation. from ctypes import Array, pointer, string_at from ctypes import c_char_p, c_int, c_size_t, c_ulong -from typing import Any, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence from typing import cast as t_cast from .enums import ( @@ -75,6 +75,8 @@ class PGconn: @property def info(self) -> List["ConninfoOption"]: + if not self.pgconn_ptr: + raise PQerror("the connection is closed") opts = impl.PQconninfo(self.pgconn_ptr) if not opts: raise MemoryError("couldn't allocate connection info") @@ -104,35 +106,35 @@ class PGconn: @property def db(self) -> bytes: - return impl.PQdb(self.pgconn_ptr) + return self._call_bytes(impl.PQdb) @property def user(self) -> bytes: - return impl.PQuser(self.pgconn_ptr) + return self._call_bytes(impl.PQuser) @property def password(self) -> bytes: - return impl.PQpass(self.pgconn_ptr) + return self._call_bytes(impl.PQpass) @property def host(self) -> bytes: - return impl.PQhost(self.pgconn_ptr) + return self._call_bytes(impl.PQhost) @property def hostaddr(self) -> bytes: - return impl.PQhostaddr(self.pgconn_ptr) + return self._call_bytes(impl.PQhostaddr) @property def port(self) -> bytes: - return impl.PQport(self.pgconn_ptr) + return self._call_bytes(impl.PQport) @property def tty(self) -> bytes: - return impl.PQtty(self.pgconn_ptr) + return self._call_bytes(impl.PQtty) @property def options(self) -> bytes: - return impl.PQoptions(self.pgconn_ptr) + return self._call_bytes(impl.PQoptions) @property def status(self) -> ConnStatus: @@ -145,39 +147,41 @@ class PGconn: return TransactionStatus(rv) def parameter_status(self, name: bytes) -> Optional[bytes]: + if not self.pgconn_ptr: + raise PQerror("the connection is closed") return impl.PQparameterStatus(self.pgconn_ptr, name) @property - def protocol_version(self) -> int: - return impl.PQprotocolVersion(self.pgconn_ptr) + def error_message(self) -> bytes: + return impl.PQerrorMessage(self.pgconn_ptr) @property - def server_version(self) -> int: - return impl.PQserverVersion(self.pgconn_ptr) + def protocol_version(self) -> int: + return self._call_int(impl.PQprotocolVersion) @property - def error_message(self) -> bytes: - return impl.PQerrorMessage(self.pgconn_ptr) + def server_version(self) -> int: + return self._call_int(impl.PQserverVersion) @property def socket(self) -> int: - return impl.PQsocket(self.pgconn_ptr) + return self._call_int(impl.PQsocket) @property def backend_pid(self) -> int: - return impl.PQbackendPID(self.pgconn_ptr) + return self._call_int(impl.PQbackendPID) @property def needs_password(self) -> bool: - return bool(impl.PQconnectionNeedsPassword(self.pgconn_ptr)) + return self._call_bool(impl.PQconnectionNeedsPassword) @property def used_password(self) -> bool: - return bool(impl.PQconnectionUsedPassword(self.pgconn_ptr)) + return self._call_bool(impl.PQconnectionUsedPassword) @property def ssl_in_use(self) -> bool: - return bool(impl.PQsslInUse(self.pgconn_ptr)) + return self._call_bool(impl.PQsslInUse) def exec_(self, command: bytes) -> "PGresult": if not isinstance(command, bytes): @@ -393,6 +397,34 @@ class PGconn: raise MemoryError("couldn't allocate empty PGresult") return PGresult(rv) + def _call_bytes( + self, func: Callable[[impl.PGconn_struct], Optional[bytes]] + ) -> bytes: + """ + Call one of the pgconn libpq functions returning a bytes pointer. + """ + if not self.pgconn_ptr: + raise PQerror("the connection is closed") + rv = func(self.pgconn_ptr) + assert rv is not None + return rv + + def _call_int(self, func: Callable[[impl.PGconn_struct], int]) -> int: + """ + Call one of the pgconn libpq functions returning an int. + """ + if not self.pgconn_ptr: + raise PQerror("the connection is closed") + return func(self.pgconn_ptr) + + def _call_bool(self, func: Callable[[impl.PGconn_struct], int]) -> bool: + """ + Call one of the pgconn libpq functions returning a logical value. + """ + if not self.pgconn_ptr: + raise PQerror("the connection is closed") + return bool(func(self.pgconn_ptr)) + class PGresult: __slots__ = ("pgresult_ptr",) diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py index 131347935..881a29cc7 100644 --- a/tests/pq/test_pgconn.py +++ b/tests/pq/test_pgconn.py @@ -56,6 +56,14 @@ def test_connect_async_bad(pq, dsn): assert conn.status == pq.ConnStatus.BAD +def test_finish(pgconn, pq): + assert pgconn.status == pq.ConnStatus.OK + pgconn.finish() + assert pgconn.status == pq.ConnStatus.BAD + pgconn.finish() + assert pgconn.status == pq.ConnStatus.BAD + + def test_info(pq, dsn, pgconn): info = pgconn.info assert len(info) > 20 @@ -69,20 +77,31 @@ def test_info(pq, dsn, pgconn): name = [o.val for o in parsed if o.keyword == b"dbname"][0] assert dbname.val == name + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.info + def test_reset(pq, pgconn): assert pgconn.status == pq.ConnStatus.OK - # TODO: break it + pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())") + assert pgconn.status == pq.ConnStatus.BAD pgconn.reset() assert pgconn.status == pq.ConnStatus.OK + # doesn't work after finish, but doesn't die either + pgconn.finish() + pgconn.reset() + assert pgconn.status == pq.ConnStatus.BAD + def test_reset_async(pq, pgconn): assert pgconn.status == pq.ConnStatus.OK - # TODO: break it + pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())") + assert pgconn.status == pq.ConnStatus.BAD pgconn.reset_start() while 1: - rv = pgconn.connect_poll() + rv = pgconn.reset_poll() if rv == pq.PollingStatus.READING: select([pgconn.socket], [], []) elif rv == pq.PollingStatus.WRITING: @@ -93,6 +112,12 @@ def test_reset_async(pq, pgconn): assert rv == pq.PollingStatus.OK assert pgconn.status == pq.ConnStatus.OK + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.reset_start() + + assert pgconn.reset_poll() == 0 + def test_ping(pq, dsn): rv = pq.PGconn.ping(dsn.encode("utf8")) @@ -105,27 +130,42 @@ def test_ping(pq, dsn): def test_db(pgconn): name = [o.val for o in pgconn.info if o.keyword == b"dbname"][0] assert pgconn.db == name + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.db def test_user(pgconn): user = [o.val for o in pgconn.info if o.keyword == b"user"][0] assert pgconn.user == user + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.user def test_password(pgconn): # not in info assert isinstance(pgconn.password, bytes) + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.password def test_host(pgconn): # might be not in info assert isinstance(pgconn.host, bytes) + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.host @pytest.mark.libpq(">= 12") def test_hostaddr(pgconn): # not in info assert isinstance(pgconn.hostaddr, bytes), pgconn.hostaddr + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.hostaddr @pytest.mark.libpq("< 12") @@ -134,14 +174,30 @@ def test_hostaddr_missing(pgconn): pgconn.hostaddr +def test_port(pgconn): + port = [o.val for o in pgconn.info if o.keyword == b"port"][0] + assert pgconn.port == port + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.port + + def test_tty(pgconn): tty = [o.val for o in pgconn.info if o.keyword == b"tty"][0] assert pgconn.tty == tty + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.tty def test_transaction_status(pq, pgconn): assert pgconn.transaction_status == pq.TransactionStatus.IDLE - # TODO: test other states + pgconn.exec_(b"begin") + assert pgconn.transaction_status == pq.TransactionStatus.INTRANS + pgconn.send_query(b"select 1") + assert pgconn.transaction_status == pq.TransactionStatus.ACTIVE + psycopg3.waiting.wait(psycopg3.Connection._exec_gen(pgconn)) + assert pgconn.transaction_status == pq.TransactionStatus.INTRANS pgconn.finish() assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN @@ -151,6 +207,9 @@ def test_parameter_status(pq, dsn, tempenv): pgconn = pq.PGconn.connect(dsn.encode("utf8")) assert pgconn.parameter_status(b"application_name") == b"psycopg3 tests" assert pgconn.parameter_status(b"wat") is None + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.parameter_status(b"application_name") def test_encoding(pq, pgconn): @@ -166,29 +225,48 @@ def test_encoding(pq, pgconn): assert res.status == pq.ExecStatus.FATAL_ERROR assert pgconn.parameter_status(b"client_encoding") == b"UTF8" + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.parameter_status(b"client_encoding") + def test_protocol_version(pgconn): assert pgconn.protocol_version == 3 + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.protocol_version def test_server_version(pgconn): assert pgconn.server_version >= 90400 + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.server_version def test_error_message(pq, pgconn): + assert pgconn.error_message == b"" res = pgconn.exec_(b"wat") assert res.status == pq.ExecStatus.FATAL_ERROR msg = pgconn.error_message assert b"wat" in msg + pgconn.finish() + assert b"NULL" in pgconn.error_message # TODO: i10n? def test_backend_pid(pgconn): assert 2 <= pgconn.backend_pid <= 65535 # Unless increased in kernel? + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.backend_pid def test_needs_password(pgconn): # assume connection worked so an eventually needed password wasn't missing assert pgconn.needs_password is False + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.needs_password def test_used_password(pq, pgconn, tempenv, dsn): @@ -206,6 +284,10 @@ def test_used_password(pq, pgconn, tempenv, dsn): if has_password: assert pgconn.used_password + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.used_password + def test_ssl_in_use(pgconn): assert isinstance(pgconn.ssl_in_use, bool) @@ -220,9 +302,18 @@ def test_ssl_in_use(pgconn): # but maybe unlikely in the tests environment? assert pgconn.ssl_in_use + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + pgconn.ssl_in_use + def test_make_empty_result(pq, pgconn): pgconn.exec_(b"wat") res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR) assert res.status == pq.ExecStatus.FATAL_ERROR assert b"wat" in res.error_message + + pgconn.finish() + res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR) + assert res.status == pq.ExecStatus.FATAL_ERROR + assert res.error_message == b""