From: Lysandros Nikolaou Date: Fri, 14 Nov 2025 16:31:11 +0000 (+0100) Subject: Use critical section to protect pgconn ptr X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=33fb9fc4908303a9c9a83fc476fd553ec7ee4511;p=thirdparty%2Fpsycopg.git Use critical section to protect pgconn ptr This PR adds critical sections to the PGconn class so that race condition between calling into the libpq bindings and calling `close` which sets the `_pgconn_ptr` to NULL are eliminated. Also add one high-level and one low-level tests that trigger TSAN warnings without this PR. --- diff --git a/psycopg_c/psycopg_c/pq.pxd b/psycopg_c/psycopg_c/pq.pxd index 8a2dbd634..1efe3bb01 100644 --- a/psycopg_c/psycopg_c/pq.pxd +++ b/psycopg_c/psycopg_c/pq.pxd @@ -13,8 +13,12 @@ cdef extern from * nogil: from psycopg_c.pq cimport libpq -ctypedef char *(*conn_bytes_f) (const libpq.PGconn *) -ctypedef int(*conn_int_f) (const libpq.PGconn *) +ctypedef char *(*conn_bytes_f) (const libpq.PGconn *) noexcept nogil +ctypedef int (*conn_int_f) (const libpq.PGconn *) noexcept nogil +ctypedef void *(*conn_f_with_param) (const libpq.PGconn *, const char *) noexcept nogil +ctypedef int (*conn_int_f_with_param) ( + const libpq.PGconn *, const char * +) noexcept nogil cdef class PGconn: diff --git a/psycopg_c/psycopg_c/pq/pgconn.pyx b/psycopg_c/psycopg_c/pq/pgconn.pyx index 759054937..68e1e6dc0 100644 --- a/psycopg_c/psycopg_c/pq/pgconn.pyx +++ b/psycopg_c/psycopg_c/pq/pgconn.pyx @@ -17,6 +17,7 @@ cdef extern from * nogil: """ pid_t getpid() +cimport cython from libc.stdio cimport fdopen from cpython.mem cimport PyMem_Free, PyMem_Malloc from cpython.bytes cimport PyBytes_AsString @@ -86,24 +87,32 @@ cdef class PGconn: return PGconn._from_ptr(pgconn) def connect_poll(self) -> int: - return _call_int(self, libpq.PQconnectPoll) + return _call_libpq_int(self, libpq.PQconnectPoll) def finish(self) -> None: - if self._pgconn_ptr is not NULL: - libpq.PQfinish(self._pgconn_ptr) - self._pgconn_ptr = NULL + with cython.critical_section(self): + if self._pgconn_ptr is not NULL: + libpq.PQfinish(self._pgconn_ptr) + self._pgconn_ptr = NULL @property def pgconn_ptr(self) -> int | None: - if self._pgconn_ptr: - return self._pgconn_ptr - else: - return None + cdef long long ptr = -1 + with cython.critical_section(self): + if self._pgconn_ptr: + ptr = self._pgconn_ptr + return ptr if ptr != -1 else None @property def info(self) -> list[ConninfoOption]: - _ensure_pgconn(self) - cdef libpq.PQconninfoOption *opts = libpq.PQconninfo(self._pgconn_ptr) + cdef libpq.PQconninfoOption *opts + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + opts = libpq.PQconninfo(pgconn_ptr) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if opts is NULL: raise MemoryError("couldn't allocate connection info") rv = _options_from_array(opts) @@ -111,15 +120,23 @@ cdef class PGconn: return rv def reset(self) -> None: - _ensure_pgconn(self) - libpq.PQreset(self._pgconn_ptr) + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + libpq.PQreset(pgconn_ptr) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") def reset_start(self) -> None: - if not libpq.PQresetStart(self._pgconn_ptr): + cdef int rv + with cython.critical_section(self): + rv = libpq.PQresetStart(self._pgconn_ptr) + if not rv: raise e.OperationalError("couldn't reset connection") def reset_poll(self) -> int: - return _call_int(self, libpq.PQresetPoll) + return _call_libpq_int(self, libpq.PQresetPoll) @classmethod def ping(self, const char *conninfo) -> int: @@ -127,51 +144,63 @@ cdef class PGconn: @property def db(self) -> bytes: - return _call_bytes(self, libpq.PQdb) + return _call_libpq_bytes(self, libpq.PQdb) @property def user(self) -> bytes: - return _call_bytes(self, libpq.PQuser) + return _call_libpq_bytes(self, libpq.PQuser) @property def password(self) -> bytes: - return _call_bytes(self, libpq.PQpass) + return _call_libpq_bytes(self, libpq.PQpass) @property def host(self) -> bytes: - return _call_bytes(self, libpq.PQhost) + return _call_libpq_bytes(self, libpq.PQhost) @property def hostaddr(self) -> bytes: _check_supported("PQhostaddr", 120000) - _ensure_pgconn(self) - cdef char *rv = libpq.PQhostaddr(self._pgconn_ptr) + cdef char *rv = NULL + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQhostaddr(pgconn_ptr) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") assert rv is not NULL return rv @property def port(self) -> bytes: - return _call_bytes(self, libpq.PQport) + return _call_libpq_bytes(self, libpq.PQport) @property def tty(self) -> bytes: - return _call_bytes(self, libpq.PQtty) + return _call_libpq_bytes(self, libpq.PQtty) @property def options(self) -> bytes: - return _call_bytes(self, libpq.PQoptions) + return _call_libpq_bytes(self, libpq.PQoptions) @property def status(self) -> int: - return libpq.PQstatus(self._pgconn_ptr) + cdef libpq.ConnStatusType rv + with cython.critical_section(self): + rv = libpq.PQstatus(self._pgconn_ptr) + return rv @property def transaction_status(self) -> int: - return libpq.PQtransactionStatus(self._pgconn_ptr) + cdef libpq.PGTransactionStatusType rv + with cython.critical_section(self): + rv = libpq.PQtransactionStatus(self._pgconn_ptr) + return rv def parameter_status(self, const char *name) -> bytes | None: - _ensure_pgconn(self) - cdef const char *rv = libpq.PQparameterStatus(self._pgconn_ptr, name) + cdef const char *rv = _call_libpq_with_param( + self, libpq.PQparameterStatus, name) if rv is not NULL: return rv else: @@ -179,7 +208,10 @@ cdef class PGconn: @property def error_message(self) -> bytes: - return libpq.PQerrorMessage(self._pgconn_ptr) + cdef char *rv + with cython.critical_section(self): + rv = libpq.PQerrorMessage(self._pgconn_ptr) + return rv def get_error_message(self, encoding: str = "") -> str: return _clean_error_message(self.error_message, encoding or self._encoding) @@ -187,72 +219,101 @@ cdef class PGconn: @property def _encoding(self) -> str: cdef const char *pgenc - if libpq.PQstatus(self._pgconn_ptr) == libpq.CONNECTION_OK: - pgenc = libpq.PQparameterStatus(self._pgconn_ptr, b"client_encoding") - if pgenc is NULL: - pgenc = b"UTF8" + cdef int status + with cython.critical_section(self): + status = libpq.PQstatus(self._pgconn_ptr) + if status == libpq.CONNECTION_OK: + pgenc = libpq.PQparameterStatus(self._pgconn_ptr, b"client_encoding") + if pgenc is NULL: + pgenc = b"UTF8" + if status == libpq.CONNECTION_OK: return pg2pyenc(pgenc) else: return "utf-8" @property def protocol_version(self) -> int: - return _call_int(self, libpq.PQprotocolVersion) + return _call_libpq_int(self, libpq.PQprotocolVersion) @property def full_protocol_version(self) -> int: _check_supported("PQfullProtocolVersion", 180000) - _ensure_pgconn(self) - return libpq.PQfullProtocolVersion(self._pgconn_ptr) + cdef int rv + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQfullProtocolVersion(pgconn_ptr) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + return rv @property def server_version(self) -> int: - return _call_int(self, libpq.PQserverVersion) + return _call_libpq_int(self, libpq.PQserverVersion) @property def socket(self) -> int: - rv = _call_int(self, libpq.PQsocket) + cdef int rv = _call_libpq_int(self, libpq.PQsocket) if rv == -1: raise e.OperationalError("the connection is lost") return rv @property def backend_pid(self) -> int: - return _call_int(self, libpq.PQbackendPID) + return _call_libpq_int(self, libpq.PQbackendPID) @property def needs_password(self) -> bool: - return bool(libpq.PQconnectionNeedsPassword(self._pgconn_ptr)) + cdef int rv + with cython.critical_section(self): + rv = libpq.PQconnectionNeedsPassword(self._pgconn_ptr) + return bool(rv) @property def used_password(self) -> bool: - return bool(libpq.PQconnectionUsedPassword(self._pgconn_ptr)) + cdef int rv + with cython.critical_section(self): + rv = libpq.PQconnectionUsedPassword(self._pgconn_ptr) + return bool(rv) @property def used_gssapi(self) -> bool: _check_supported("PQconnectionUsedGSSAPI", 160000) - return bool(libpq.PQconnectionUsedGSSAPI(self._pgconn_ptr)) + cdef int rv + with cython.critical_section(self): + rv = libpq.PQconnectionUsedGSSAPI(self._pgconn_ptr) + return bool(rv) @property def ssl_in_use(self) -> bool: - return bool(_call_int(self, libpq.PQsslInUse)) + return bool(_call_libpq_int(self, libpq.PQsslInUse)) def exec_(self, const char *command) -> PGresult: - _ensure_pgconn(self) cdef libpq.PGresult *pgresult - with nogil: - pgresult = libpq.PQexec(self._pgconn_ptr, command) + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + with nogil: + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + pgresult = libpq.PQexec(pgconn_ptr, command) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if pgresult is NULL: raise e.OperationalError( f"executing query failed: {self.get_error_message()}") - return PGresult._from_ptr(pgresult) def send_query(self, const char *command) -> None: - _ensure_pgconn(self) cdef int rv - with nogil: - rv = libpq.PQsendQuery(self._pgconn_ptr, command) + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + with nogil: + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQsendQuery(pgconn_ptr, command) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if not rv: raise e.OperationalError( f"sending query failed: {self.get_error_message()}") @@ -265,8 +326,6 @@ cdef class PGconn: param_formats: Sequence[int] | None = None, int result_format = PqFormat.TEXT, ) -> PGresult: - _ensure_pgconn(self) - cdef Py_ssize_t cnparams cdef libpq.Oid *ctypes cdef char *const *cvalues @@ -275,12 +334,18 @@ cdef class PGconn: cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( param_values, param_types, param_formats) + cdef libpq.PGconn *pgconn_ptr cdef libpq.PGresult *pgresult - with nogil: - pgresult = libpq.PQexecParams( - self._pgconn_ptr, command, cnparams, ctypes, - cvalues, clengths, cformats, result_format) + with cython.critical_section(self): + with nogil: + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + pgresult = libpq.PQexecParams( + pgconn_ptr, command, cnparams, ctypes, + cvalues, clengths, cformats, result_format) _clear_query_params(ctypes, cvalues, clengths, cformats) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if pgresult is NULL: raise e.OperationalError( f"executing query failed: {self.get_error_message()}") @@ -294,8 +359,6 @@ cdef class PGconn: param_formats: Sequence[int] | None = None, int result_format = PqFormat.TEXT, ) -> None: - _ensure_pgconn(self) - cdef Py_ssize_t cnparams cdef libpq.Oid *ctypes cdef char *const *cvalues @@ -304,12 +367,18 @@ cdef class PGconn: cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( param_values, param_types, param_formats) + cdef libpq.PGconn *pgconn_ptr cdef int rv - with nogil: - rv = libpq.PQsendQueryParams( - self._pgconn_ptr, command, cnparams, ctypes, - cvalues, clengths, cformats, result_format) + with cython.critical_section(self): + with nogil: + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQsendQueryParams( + pgconn_ptr, command, cnparams, ctypes, + cvalues, clengths, cformats, result_format) _clear_query_params(ctypes, cvalues, clengths, cformats) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if not rv: raise e.OperationalError( f"sending query and params failed: {self.get_error_message()}" @@ -321,8 +390,6 @@ cdef class PGconn: const char *command, param_types: Sequence[int] | None = None, ) -> None: - _ensure_pgconn(self) - cdef int i cdef types_fast cdef Py_ssize_t nparams = 0 @@ -336,12 +403,18 @@ cdef class PGconn: for i in range(nparams): atypes[i] = PySequence_Fast_GET_ITEM(types_fast, i) + cdef libpq.PGconn *pgconn_ptr cdef int rv - with nogil: - rv = libpq.PQsendPrepare( - self._pgconn_ptr, name, command, nparams, atypes - ) + with cython.critical_section(self): + with nogil: + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQsendPrepare( + pgconn_ptr, name, command, nparams, atypes + ) PyMem_Free(atypes) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if not rv: raise e.OperationalError( f"sending query and params failed: {self.get_error_message()}" @@ -354,8 +427,6 @@ cdef class PGconn: param_formats: Sequence[int] | None = None, int result_format = PqFormat.TEXT, ) -> None: - _ensure_pgconn(self) - cdef Py_ssize_t cnparams cdef libpq.Oid *ctypes cdef char *const *cvalues @@ -364,12 +435,18 @@ cdef class PGconn: cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( param_values, None, param_formats) + cdef libpq.PGconn *pgconn_ptr cdef int rv - with nogil: - rv = libpq.PQsendQueryPrepared( - self._pgconn_ptr, name, cnparams, cvalues, - clengths, cformats, result_format) + with cython.critical_section(self): + with nogil: + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQsendQueryPrepared( + pgconn_ptr, name, cnparams, cvalues, + clengths, cformats, result_format) _clear_query_params(ctypes, cvalues, clengths, cformats) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if not rv: raise e.OperationalError( f"sending prepared query failed: {self.get_error_message()}" @@ -381,8 +458,6 @@ cdef class PGconn: const char *command, param_types: Sequence[int] | None = None, ) -> PGresult: - _ensure_pgconn(self) - cdef int i cdef types_fast cdef Py_ssize_t nparams = 0 @@ -396,11 +471,17 @@ cdef class PGconn: for i in range(nparams): atypes[i] = PySequence_Fast_GET_ITEM(types_fast, i) + cdef libpq.PGconn *pgconn_ptr cdef libpq.PGresult *rv - with nogil: - rv = libpq.PQprepare( - self._pgconn_ptr, name, command, nparams, atypes) + with cython.critical_section(self): + with nogil: + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQprepare( + pgconn_ptr, name, command, nparams, atypes) PyMem_Free(atypes) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if rv is NULL: raise e.OperationalError( f"preparing query failed: {self.get_error_message()}") @@ -413,8 +494,6 @@ cdef class PGconn: param_formats: Sequence[int] | None = None, int result_format = PqFormat.TEXT, ) -> PGresult: - _ensure_pgconn(self) - cdef Py_ssize_t cnparams cdef libpq.Oid *ctypes cdef char *const *cvalues @@ -423,14 +502,19 @@ cdef class PGconn: cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( param_values, None, param_formats) + cdef libpq.PGconn *pgconn_ptr cdef libpq.PGresult *rv - with nogil: - rv = libpq.PQexecPrepared( - self._pgconn_ptr, name, cnparams, - cvalues, - clengths, cformats, result_format) - + with cython.critical_section(self): + with nogil: + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQexecPrepared( + pgconn_ptr, name, cnparams, + cvalues, + clengths, cformats, result_format) _clear_query_params(ctypes, cvalues, clengths, cformats) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if rv is NULL: raise e.OperationalError( f"executing prepared query failed: {self.get_error_message()}" @@ -438,8 +522,8 @@ cdef class PGconn: return PGresult._from_ptr(rv) def describe_prepared(self, const char *name) -> PGresult: - _ensure_pgconn(self) - cdef libpq.PGresult *rv = libpq.PQdescribePrepared(self._pgconn_ptr, name) + cdef libpq.PGresult *rv = _call_libpq_with_param( + self, libpq.PQdescribePrepared, name) if rv is NULL: raise e.OperationalError( f"describe prepared failed: {self.get_error_message()}" @@ -447,16 +531,16 @@ cdef class PGconn: return PGresult._from_ptr(rv) def send_describe_prepared(self, const char *name) -> None: - _ensure_pgconn(self) - cdef int rv = libpq.PQsendDescribePrepared(self._pgconn_ptr, name) + cdef int rv = _call_libpq_int_with_param( + self, libpq.PQsendDescribePrepared, name) if not rv: raise e.OperationalError( f"sending describe prepared failed: {self.get_error_message()}" ) def describe_portal(self, const char *name) -> PGresult: - _ensure_pgconn(self) - cdef libpq.PGresult *rv = libpq.PQdescribePortal(self._pgconn_ptr, name) + cdef libpq.PGresult *rv = _call_libpq_with_param( + self, libpq.PQdescribePortal, name) if rv is NULL: raise e.OperationalError( f"describe prepared failed: {self.get_error_message()}" @@ -464,8 +548,8 @@ cdef class PGconn: return PGresult._from_ptr(rv) def send_describe_portal(self, const char *name) -> None: - _ensure_pgconn(self) - cdef int rv = libpq.PQsendDescribePortal(self._pgconn_ptr, name) + cdef int rv = _call_libpq_int_with_param( + self, libpq.PQsendDescribePortal, name) if not rv: raise e.OperationalError( f"sending describe prepared failed: {self.get_error_message()}" @@ -473,8 +557,14 @@ cdef class PGconn: def close_prepared(self, const char *name) -> PGresult: _check_supported("PQclosePrepared", 170000) - _ensure_pgconn(self) - cdef libpq.PGresult *rv = libpq.PQclosePrepared(self._pgconn_ptr, name) + cdef libpq.PGresult *rv + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQclosePrepared(pgconn_ptr, name) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if rv is NULL: raise e.OperationalError( f"close prepared failed: {self.get_error_message()}" @@ -483,8 +573,14 @@ cdef class PGconn: def send_close_prepared(self, const char *name) -> None: _check_supported("PQsendClosePrepared", 170000) - _ensure_pgconn(self) - cdef int rv = libpq.PQsendClosePrepared(self._pgconn_ptr, name) + cdef int rv + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQsendClosePrepared(pgconn_ptr, name) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if not rv: raise e.OperationalError( f"sending close prepared failed: {self.get_error_message()}" @@ -492,8 +588,14 @@ cdef class PGconn: def close_portal(self, const char *name) -> PGresult: _check_supported("PQclosePortal", 170000) - _ensure_pgconn(self) - cdef libpq.PGresult *rv = libpq.PQclosePortal(self._pgconn_ptr, name) + cdef libpq.PGresult *rv + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQclosePortal(pgconn_ptr, name) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if rv is NULL: raise e.OperationalError( f"close prepared failed: {self.get_error_message()}" @@ -502,73 +604,104 @@ cdef class PGconn: def send_close_portal(self, const char *name) -> None: _check_supported("PQsendClosePortal", 170000) - _ensure_pgconn(self) - cdef int rv = libpq.PQsendClosePortal(self._pgconn_ptr, name) + cdef int rv + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = libpq.PQsendClosePortal(pgconn_ptr, name) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if not rv: raise e.OperationalError( f"sending close prepared failed: {self.get_error_message()}" ) def get_result(self) -> "PGresult" | None: - cdef libpq.PGresult *pgresult = libpq.PQgetResult(self._pgconn_ptr) + cdef libpq.PGresult *pgresult + with cython.critical_section(self): + pgresult = libpq.PQgetResult(self._pgconn_ptr) if pgresult is NULL: return None return PGresult._from_ptr(pgresult) def consume_input(self) -> None: - if 1 != libpq.PQconsumeInput(self._pgconn_ptr): + cdef int rv + with cython.critical_section(self): + rv = libpq.PQconsumeInput(self._pgconn_ptr) + if 1 != rv: raise e.OperationalError( f"consuming input failed: {self.get_error_message()}") def is_busy(self) -> int: cdef int rv - with nogil: - rv = libpq.PQisBusy(self._pgconn_ptr) + with cython.critical_section(self): + with nogil: + rv = libpq.PQisBusy(self._pgconn_ptr) return rv @property def nonblocking(self) -> int: - return libpq.PQisnonblocking(self._pgconn_ptr) + cdef int rv + with cython.critical_section(self): + rv = libpq.PQisnonblocking(self._pgconn_ptr) + return rv @nonblocking.setter def nonblocking(self, int arg) -> None: - if 0 > libpq.PQsetnonblocking(self._pgconn_ptr, arg): + cdef int rv + with cython.critical_section(self): + rv = libpq.PQsetnonblocking(self._pgconn_ptr, arg) + if 0 > rv: raise e.OperationalError( f"setting nonblocking failed: {self.get_error_message()}") cpdef int flush(self) except -1: - if self._pgconn_ptr == NULL: - raise e.OperationalError("flushing failed: the connection is closed") - cdef int rv = libpq.PQflush(self._pgconn_ptr) + cdef int rv + with cython.critical_section(self): + if self._pgconn_ptr is NULL: + raise e.OperationalError("flushing failed: the connection is closed") + rv = libpq.PQflush(self._pgconn_ptr) if rv < 0: raise e.OperationalError(f"flushing failed: {self.get_error_message()}") return rv def set_single_row_mode(self) -> None: - if not libpq.PQsetSingleRowMode(self._pgconn_ptr): + cdef int rv + with cython.critical_section(self): + rv = libpq.PQsetSingleRowMode(self._pgconn_ptr) + if not rv: raise e.OperationalError("setting single row mode failed") def set_chunked_rows_mode(self, size: int) -> None: - if not libpq.PQsetChunkedRowsMode(self._pgconn_ptr, size): + cdef int rv + with cython.critical_section(self): + rv = libpq.PQsetChunkedRowsMode(self._pgconn_ptr, size) + if not rv: raise e.OperationalError("setting chunked rows mode failed") def cancel_conn(self) -> PGcancelConn: _check_supported("PQcancelCreate", 170000) - cdef libpq.PGcancelConn *ptr = libpq.PQcancelCreate(self._pgconn_ptr) + cdef libpq.PGcancelConn *ptr + with cython.critical_section(self): + ptr = libpq.PQcancelCreate(self._pgconn_ptr) if not ptr: raise e.OperationalError("couldn't create cancelConn object") return PGcancelConn._from_ptr(ptr) def get_cancel(self) -> PGcancel: - cdef libpq.PGcancel *ptr = libpq.PQgetCancel(self._pgconn_ptr) + cdef libpq.PGcancel *ptr + with cython.critical_section(self): + ptr = libpq.PQgetCancel(self._pgconn_ptr) if not ptr: raise e.OperationalError("couldn't create cancel object") return PGcancel._from_ptr(ptr) cpdef object notifies(self): cdef libpq.PGnotify *ptr - with nogil: - ptr = libpq.PQnotifies(self._pgconn_ptr) + with cython.critical_section(self): + with nogil: + ptr = libpq.PQnotifies(self._pgconn_ptr) if ptr: ret = PGnotify(ptr.relname, ptr.be_pid, ptr.extra) libpq.PQfreemem(ptr) @@ -582,7 +715,8 @@ cdef class PGconn: cdef Py_ssize_t length _buffer_as_string_and_size(buffer, &cbuffer, &length) - rv = libpq.PQputCopyData(self._pgconn_ptr, cbuffer, length) + with cython.critical_section(self): + rv = libpq.PQputCopyData(self._pgconn_ptr, cbuffer, length) if rv < 0: raise e.OperationalError( f"sending copy data failed: {self.get_error_message()}") @@ -593,7 +727,9 @@ cdef class PGconn: cdef const char *cerr = NULL if error is not None: cerr = PyBytes_AsString(error) - rv = libpq.PQputCopyEnd(self._pgconn_ptr, cerr) + + with cython.critical_section(self): + rv = libpq.PQputCopyEnd(self._pgconn_ptr, cerr) if rv < 0: raise e.OperationalError( f"sending copy end failed: {self.get_error_message()}") @@ -602,7 +738,9 @@ cdef class PGconn: def get_copy_data(self, int async_) -> tuple[int, memoryview]: cdef char *buffer_ptr = NULL cdef int nbytes - nbytes = libpq.PQgetCopyData(self._pgconn_ptr, &buffer_ptr, async_) + + with cython.critical_section(self): + nbytes = libpq.PQgetCopyData(self._pgconn_ptr, &buffer_ptr, async_) if nbytes == -2: raise e.OperationalError( f"receiving copy data failed: {self.get_error_message()}") @@ -617,14 +755,17 @@ cdef class PGconn: if sys.platform != "linux": raise e.NotSupportedError("currently only supported on Linux") stream = fdopen(fileno, b"w") - libpq.PQtrace(self._pgconn_ptr, stream) + with cython.critical_section(self): + libpq.PQtrace(self._pgconn_ptr, stream) def set_trace_flags(self, flags: Trace) -> None: _check_supported("PQsetTraceFlags", 140000) - libpq.PQsetTraceFlags(self._pgconn_ptr, flags) + with cython.critical_section(self): + libpq.PQsetTraceFlags(self._pgconn_ptr, flags) def untrace(self) -> None: - libpq.PQuntrace(self._pgconn_ptr) + with cython.critical_section(self): + libpq.PQuntrace(self._pgconn_ptr) def encrypt_password( self, const char *passwd, const char *user, algorithm = None @@ -635,7 +776,8 @@ cdef class PGconn: cdef const char *calgo = NULL if algorithm: calgo = algorithm - out = libpq.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, calgo) + with cython.critical_section(self): + out = libpq.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, calgo) if not out: raise e.OperationalError( f"password encryption failed: {self.get_error_message()}" @@ -651,15 +793,18 @@ cdef class PGconn: _check_supported("PQchangePassword", 170000) cdef libpq.PGresult *res - res = libpq.PQchangePassword(self._pgconn_ptr, user, passwd) + with cython.critical_section(self): + res = libpq.PQchangePassword(self._pgconn_ptr, user, passwd) if libpq.PQresultStatus(res) != ExecStatus.COMMAND_OK: raise e.OperationalError( f"password encryption failed: {self.get_error_message()}" ) def make_empty_result(self, int exec_status) -> PGresult: - cdef libpq.PGresult *rv = libpq.PQmakeEmptyPGresult( - self._pgconn_ptr, exec_status) + cdef libpq.PGresult *rv + with cython.critical_section(self): + rv = libpq.PQmakeEmptyPGresult( + self._pgconn_ptr, exec_status) if not rv: raise MemoryError("couldn't allocate empty PGresult") return PGresult._from_ptr(rv) @@ -672,8 +817,10 @@ cdef class PGconn: """ if libpq.PG_VERSION_NUM < 140000: return libpq.PQ_PIPELINE_OFF - cdef int status = libpq.PQpipelineStatus(self._pgconn_ptr) - return status + cdef libpq.PGpipelineStatus rv + with cython.critical_section(self): + rv = libpq.PQpipelineStatus(self._pgconn_ptr) + return rv def enter_pipeline_mode(self) -> None: """Enter pipeline mode. @@ -682,7 +829,10 @@ cdef class PGconn: mode. """ _check_supported("PQenterPipelineMode", 140000) - if libpq.PQenterPipelineMode(self._pgconn_ptr) != 1: + cdef int rv + with cython.critical_section(self): + rv = libpq.PQenterPipelineMode(self._pgconn_ptr) + if rv != 1: raise e.OperationalError("failed to enter pipeline mode") def exit_pipeline_mode(self) -> None: @@ -692,7 +842,10 @@ cdef class PGconn: mode. """ _check_supported("PQexitPipelineMode", 140000) - if libpq.PQexitPipelineMode(self._pgconn_ptr) != 1: + cdef int rv + with cython.critical_section(self): + rv = libpq.PQexitPipelineMode(self._pgconn_ptr) + if rv != 1: raise e.OperationalError(self.get_error_message()) def pipeline_sync(self) -> None: @@ -702,7 +855,9 @@ cdef class PGconn: or if sync failed. """ _check_supported("PQpipelineSync", 140000) - rv = libpq.PQpipelineSync(self._pgconn_ptr) + cdef int rv + with cython.critical_section(self): + rv = libpq.PQpipelineSync(self._pgconn_ptr) if rv == 0: raise e.OperationalError("connection not in pipeline mode") if rv != 1: @@ -714,39 +869,70 @@ cdef class PGconn: :raises ~e.OperationalError: if the flush request failed. """ _check_supported("PQsendFlushRequest ", 140000) - cdef int rv = libpq.PQsendFlushRequest(self._pgconn_ptr) + cdef int rv + with cython.critical_section(self): + rv = libpq.PQsendFlushRequest(self._pgconn_ptr) if rv == 0: raise e.OperationalError( f"flush request failed: {self.get_error_message()}") -cdef int _ensure_pgconn(PGconn pgconn) except 0: - if pgconn._pgconn_ptr is not NULL: - return 1 - - raise e.OperationalError("the connection is closed") - - -cdef char *_call_bytes(PGconn pgconn, conn_bytes_f func) except NULL: - """ - Call one of the pgconn libpq functions returning a bytes pointer. - """ - if not _ensure_pgconn(pgconn): - return NULL - cdef char *rv = func(pgconn._pgconn_ptr) +cdef void *_call_libpq_with_param( + PGconn self, + conn_f_with_param func, + const char *param +): + cdef void *rv = NULL + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = func(pgconn_ptr, param) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + return rv + + +cdef int _call_libpq_int(PGconn self, conn_int_f func): + cdef int rv + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = func(pgconn_ptr) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + return rv + + +cdef int _call_libpq_int_with_param( + PGconn self, + conn_int_f_with_param func, + const char *param +): + cdef int rv + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = func(pgconn_ptr, param) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + return rv + + +cdef bytes _call_libpq_bytes(PGconn self, conn_bytes_f func): + cdef char *rv + cdef libpq.PGconn *pgconn_ptr + with cython.critical_section(self): + pgconn_ptr = self._pgconn_ptr + if pgconn_ptr is not NULL: + rv = func(pgconn_ptr) + if pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") if rv is not NULL: return rv - else: - return b"" - - -cdef int _call_int(PGconn pgconn, conn_int_f func) except -2: - """ - Call one of the pgconn libpq functions returning an int. - """ - if not _ensure_pgconn(pgconn): - return -2 - return func(pgconn._pgconn_ptr) + return b"" cdef void notice_receiver(void *arg, const libpq.PGresult *res_ptr) noexcept with gil: diff --git a/tests/test_free_threading.py b/tests/test_free_threading.py index cfe80906f..b4b5122b1 100644 --- a/tests/test_free_threading.py +++ b/tests/test_free_threading.py @@ -1,3 +1,4 @@ +import time import threading from concurrent.futures import ThreadPoolExecutor @@ -182,3 +183,54 @@ def test_same_cursor_from_multiple_threads_no_crash(conn): future.result() cur.close() + + +@pytest.mark.slow +@pytest.mark.usefixtures("testctx") +def test_connection_finish_while_executing(conn): + with conn.cursor() as cur: + cur.execute("insert into testctx values (1)") + conn.commit() + + def closer(): + time.sleep(1) + conn.close() + + def reader(): + cur = conn.cursor() + try: + while True: + cur.execute("select id from testctx") + assert [row[0] for row in cur.fetchall()] == [1] + except Exception: + pass + + with ThreadPoolExecutor(max_workers=2) as tpe: + future2 = tpe.submit(reader) + future1 = tpe.submit(closer) + future1.result() + future2.result() + + +@pytest.mark.slow +def test_connection_close_race_condition(dsn): + conn = psycopg.connect(dsn, autocommit=True) + barrier = threading.Barrier(parties=2) + + def reader(): + barrier.wait() + messages = [conn.pgconn.error_message for _ in range(100)] + return messages + + def closer(): + barrier.wait() + conn.pgconn.finish() + + with ThreadPoolExecutor(max_workers=2) as tpe: + reader_future = tpe.submit(reader) + closer_future = tpe.submit(closer) + error_messages = reader_future.result() + closer_future.result() + + for error_message in error_messages: + assert error_message in (b"", b"connection pointer is NULL\n")