]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Use critical section to protect pgconn ptr
authorLysandros Nikolaou <lisandrosnik@gmail.com>
Fri, 14 Nov 2025 16:31:11 +0000 (17:31 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 25 Jan 2026 13:10:53 +0000 (13:10 +0000)
This PR adds critical sections to the PGconn class so that
race condition between calling into the libpq bindings and
calling `close` which sets the `_pgconn_ptr` to NULL are
eliminated.

Also add one high-level and one low-level tests that trigger TSAN
warnings without this PR.

psycopg_c/psycopg_c/pq.pxd
psycopg_c/psycopg_c/pq/pgconn.pyx
tests/test_free_threading.py

index 8a2dbd6343c11db141bbc0efa58c28a2ecb29fc5..1efe3bb01fe686c54f0267bfa1a767c2757cc01d 100644 (file)
@@ -13,8 +13,12 @@ cdef extern from * nogil:
 
 from psycopg_c.pq cimport libpq
 
-ctypedef char *(*conn_bytes_f) (const libpq.PGconn *)
-ctypedef int(*conn_int_f) (const libpq.PGconn *)
+ctypedef char *(*conn_bytes_f) (const libpq.PGconn *) noexcept nogil
+ctypedef int (*conn_int_f) (const libpq.PGconn *) noexcept nogil
+ctypedef void *(*conn_f_with_param) (const libpq.PGconn *, const char *) noexcept nogil
+ctypedef int (*conn_int_f_with_param) (
+    const libpq.PGconn *, const char *
+) noexcept nogil
 
 
 cdef class PGconn:
index 7590549376e0d2cb41e46894f1af2dce9ea3a8bf..68e1e6dc05e23979ac0e06723dd6601f7cdd9906 100644 (file)
@@ -17,6 +17,7 @@ cdef extern from * nogil:
     """
     pid_t getpid()
 
+cimport cython
 from libc.stdio cimport fdopen
 from cpython.mem cimport PyMem_Free, PyMem_Malloc
 from cpython.bytes cimport PyBytes_AsString
@@ -86,24 +87,32 @@ cdef class PGconn:
         return PGconn._from_ptr(pgconn)
 
     def connect_poll(self) -> int:
-        return _call_int(self, <conn_int_f>libpq.PQconnectPoll)
+        return _call_libpq_int(self, <conn_int_f>libpq.PQconnectPoll)
 
     def finish(self) -> None:
-        if self._pgconn_ptr is not NULL:
-            libpq.PQfinish(self._pgconn_ptr)
-            self._pgconn_ptr = NULL
+        with cython.critical_section(self):
+            if self._pgconn_ptr is not NULL:
+                libpq.PQfinish(self._pgconn_ptr)
+                self._pgconn_ptr = NULL
 
     @property
     def pgconn_ptr(self) -> int | None:
-        if self._pgconn_ptr:
-            return <long long><void *>self._pgconn_ptr
-        else:
-            return None
+        cdef long long ptr = -1
+        with cython.critical_section(self):
+            if self._pgconn_ptr:
+                ptr = <long long><void *>self._pgconn_ptr
+        return ptr if ptr != -1 else None
 
     @property
     def info(self) -> list[ConninfoOption]:
-        _ensure_pgconn(self)
-        cdef libpq.PQconninfoOption *opts = libpq.PQconninfo(self._pgconn_ptr)
+        cdef libpq.PQconninfoOption *opts
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            pgconn_ptr = self._pgconn_ptr
+            if pgconn_ptr is not NULL:
+                opts = libpq.PQconninfo(pgconn_ptr)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if opts is NULL:
             raise MemoryError("couldn't allocate connection info")
         rv = _options_from_array(opts)
@@ -111,15 +120,23 @@ cdef class PGconn:
         return rv
 
     def reset(self) -> None:
-        _ensure_pgconn(self)
-        libpq.PQreset(self._pgconn_ptr)
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            pgconn_ptr = self._pgconn_ptr
+            if pgconn_ptr is not NULL:
+                libpq.PQreset(pgconn_ptr)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
 
     def reset_start(self) -> None:
-        if not libpq.PQresetStart(self._pgconn_ptr):
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQresetStart(self._pgconn_ptr)
+        if not rv:
             raise e.OperationalError("couldn't reset connection")
 
     def reset_poll(self) -> int:
-        return _call_int(self, <conn_int_f>libpq.PQresetPoll)
+        return _call_libpq_int(self, <conn_int_f>libpq.PQresetPoll)
 
     @classmethod
     def ping(self, const char *conninfo) -> int:
@@ -127,51 +144,63 @@ cdef class PGconn:
 
     @property
     def db(self) -> bytes:
-        return _call_bytes(self, libpq.PQdb)
+        return _call_libpq_bytes(self, libpq.PQdb)
 
     @property
     def user(self) -> bytes:
-        return _call_bytes(self, libpq.PQuser)
+        return _call_libpq_bytes(self, libpq.PQuser)
 
     @property
     def password(self) -> bytes:
-        return _call_bytes(self, libpq.PQpass)
+        return _call_libpq_bytes(self, libpq.PQpass)
 
     @property
     def host(self) -> bytes:
-        return _call_bytes(self, libpq.PQhost)
+        return _call_libpq_bytes(self, libpq.PQhost)
 
     @property
     def hostaddr(self) -> bytes:
         _check_supported("PQhostaddr", 120000)
-        _ensure_pgconn(self)
-        cdef char *rv = libpq.PQhostaddr(self._pgconn_ptr)
+        cdef char *rv = NULL
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            pgconn_ptr = self._pgconn_ptr
+            if pgconn_ptr is not NULL:
+                rv = libpq.PQhostaddr(pgconn_ptr)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         assert rv is not NULL
         return rv
 
     @property
     def port(self) -> bytes:
-        return _call_bytes(self, libpq.PQport)
+        return _call_libpq_bytes(self, libpq.PQport)
 
     @property
     def tty(self) -> bytes:
-        return _call_bytes(self, libpq.PQtty)
+        return _call_libpq_bytes(self, libpq.PQtty)
 
     @property
     def options(self) -> bytes:
-        return _call_bytes(self, libpq.PQoptions)
+        return _call_libpq_bytes(self, libpq.PQoptions)
 
     @property
     def status(self) -> int:
-        return libpq.PQstatus(self._pgconn_ptr)
+        cdef libpq.ConnStatusType rv
+        with cython.critical_section(self):
+            rv = libpq.PQstatus(self._pgconn_ptr)
+        return rv
 
     @property
     def transaction_status(self) -> int:
-        return libpq.PQtransactionStatus(self._pgconn_ptr)
+        cdef libpq.PGTransactionStatusType rv
+        with cython.critical_section(self):
+            rv = libpq.PQtransactionStatus(self._pgconn_ptr)
+        return rv
 
     def parameter_status(self, const char *name) -> bytes | None:
-        _ensure_pgconn(self)
-        cdef const char *rv = libpq.PQparameterStatus(self._pgconn_ptr, name)
+        cdef const char *rv = <const char *>_call_libpq_with_param(
+            self, <conn_f_with_param>libpq.PQparameterStatus, name)
         if rv is not NULL:
             return rv
         else:
@@ -179,7 +208,10 @@ cdef class PGconn:
 
     @property
     def error_message(self) -> bytes:
-        return libpq.PQerrorMessage(self._pgconn_ptr)
+        cdef char *rv
+        with cython.critical_section(self):
+            rv = libpq.PQerrorMessage(self._pgconn_ptr)
+        return rv
 
     def get_error_message(self, encoding: str = "") -> str:
         return _clean_error_message(self.error_message, encoding or self._encoding)
@@ -187,72 +219,101 @@ cdef class PGconn:
     @property
     def _encoding(self) -> str:
         cdef const char *pgenc
-        if libpq.PQstatus(self._pgconn_ptr) == libpq.CONNECTION_OK:
-            pgenc = libpq.PQparameterStatus(self._pgconn_ptr, b"client_encoding")
-            if pgenc is NULL:
-                pgenc = b"UTF8"
+        cdef int status
+        with cython.critical_section(self):
+            status = libpq.PQstatus(self._pgconn_ptr)
+            if status == libpq.CONNECTION_OK:
+                pgenc = libpq.PQparameterStatus(self._pgconn_ptr, b"client_encoding")
+                if pgenc is NULL:
+                    pgenc = b"UTF8"
+        if status == libpq.CONNECTION_OK:
             return pg2pyenc(pgenc)
         else:
             return "utf-8"
 
     @property
     def protocol_version(self) -> int:
-        return _call_int(self, libpq.PQprotocolVersion)
+        return _call_libpq_int(self, <conn_int_f>libpq.PQprotocolVersion)
 
     @property
     def full_protocol_version(self) -> int:
         _check_supported("PQfullProtocolVersion", 180000)
-        _ensure_pgconn(self)
-        return libpq.PQfullProtocolVersion(self._pgconn_ptr)
+        cdef int rv
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            pgconn_ptr = self._pgconn_ptr
+            if pgconn_ptr is not NULL:
+                rv = libpq.PQfullProtocolVersion(pgconn_ptr)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
+        return rv
 
     @property
     def server_version(self) -> int:
-        return _call_int(self, libpq.PQserverVersion)
+        return _call_libpq_int(self, <conn_int_f>libpq.PQserverVersion)
 
     @property
     def socket(self) -> int:
-        rv = _call_int(self, libpq.PQsocket)
+        cdef int rv = _call_libpq_int(self, <conn_int_f>libpq.PQsocket)
         if rv == -1:
             raise e.OperationalError("the connection is lost")
         return rv
 
     @property
     def backend_pid(self) -> int:
-        return _call_int(self, libpq.PQbackendPID)
+        return _call_libpq_int(self, <conn_int_f>libpq.PQbackendPID)
 
     @property
     def needs_password(self) -> bool:
-        return bool(libpq.PQconnectionNeedsPassword(self._pgconn_ptr))
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQconnectionNeedsPassword(self._pgconn_ptr)
+        return bool(rv)
 
     @property
     def used_password(self) -> bool:
-        return bool(libpq.PQconnectionUsedPassword(self._pgconn_ptr))
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQconnectionUsedPassword(self._pgconn_ptr)
+        return bool(rv)
 
     @property
     def used_gssapi(self) -> bool:
         _check_supported("PQconnectionUsedGSSAPI", 160000)
-        return bool(libpq.PQconnectionUsedGSSAPI(self._pgconn_ptr))
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQconnectionUsedGSSAPI(self._pgconn_ptr)
+        return bool(rv)
 
     @property
     def ssl_in_use(self) -> bool:
-        return bool(_call_int(self, <conn_int_f>libpq.PQsslInUse))
+        return bool(_call_libpq_int(self, <conn_int_f>libpq.PQsslInUse))
 
     def exec_(self, const char *command) -> PGresult:
-        _ensure_pgconn(self)
         cdef libpq.PGresult *pgresult
-        with nogil:
-            pgresult = libpq.PQexec(self._pgconn_ptr, command)
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            with nogil:
+                pgconn_ptr = self._pgconn_ptr
+                if pgconn_ptr is not NULL:
+                    pgresult = libpq.PQexec(pgconn_ptr, command)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if pgresult is NULL:
             raise e.OperationalError(
                 f"executing query failed: {self.get_error_message()}")
-
         return PGresult._from_ptr(pgresult)
 
     def send_query(self, const char *command) -> None:
-        _ensure_pgconn(self)
         cdef int rv
-        with nogil:
-            rv = libpq.PQsendQuery(self._pgconn_ptr, command)
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            with nogil:
+                pgconn_ptr = self._pgconn_ptr
+                if pgconn_ptr is not NULL:
+                    rv = libpq.PQsendQuery(pgconn_ptr, command)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if not rv:
             raise e.OperationalError(
                 f"sending query failed: {self.get_error_message()}")
@@ -265,8 +326,6 @@ cdef class PGconn:
         param_formats: Sequence[int] | None = None,
         int result_format = PqFormat.TEXT,
     ) -> PGresult:
-        _ensure_pgconn(self)
-
         cdef Py_ssize_t cnparams
         cdef libpq.Oid *ctypes
         cdef char *const *cvalues
@@ -275,12 +334,18 @@ cdef class PGconn:
         cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
             param_values, param_types, param_formats)
 
+        cdef libpq.PGconn *pgconn_ptr
         cdef libpq.PGresult *pgresult
-        with nogil:
-            pgresult = libpq.PQexecParams(
-                self._pgconn_ptr, command, <int>cnparams, ctypes,
-                <const char *const *>cvalues, clengths, cformats, result_format)
+        with cython.critical_section(self):
+            with nogil:
+                pgconn_ptr = self._pgconn_ptr
+                if pgconn_ptr is not NULL:
+                    pgresult = libpq.PQexecParams(
+                        pgconn_ptr, command, <int>cnparams, ctypes,
+                        <const char *const *>cvalues, clengths, cformats, result_format)
         _clear_query_params(ctypes, cvalues, clengths, cformats)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if pgresult is NULL:
             raise e.OperationalError(
                 f"executing query failed: {self.get_error_message()}")
@@ -294,8 +359,6 @@ cdef class PGconn:
         param_formats: Sequence[int] | None = None,
         int result_format = PqFormat.TEXT,
     ) -> None:
-        _ensure_pgconn(self)
-
         cdef Py_ssize_t cnparams
         cdef libpq.Oid *ctypes
         cdef char *const *cvalues
@@ -304,12 +367,18 @@ cdef class PGconn:
         cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
             param_values, param_types, param_formats)
 
+        cdef libpq.PGconn *pgconn_ptr
         cdef int rv
-        with nogil:
-            rv = libpq.PQsendQueryParams(
-                self._pgconn_ptr, command, <int>cnparams, ctypes,
-                <const char *const *>cvalues, clengths, cformats, result_format)
+        with cython.critical_section(self):
+            with nogil:
+                pgconn_ptr = self._pgconn_ptr
+                if pgconn_ptr is not NULL:
+                    rv = libpq.PQsendQueryParams(
+                        pgconn_ptr, command, <int>cnparams, ctypes,
+                        <const char *const *>cvalues, clengths, cformats, result_format)
         _clear_query_params(ctypes, cvalues, clengths, cformats)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if not rv:
             raise e.OperationalError(
                 f"sending query and params failed: {self.get_error_message()}"
@@ -321,8 +390,6 @@ cdef class PGconn:
         const char *command,
         param_types: Sequence[int] | None = None,
     ) -> None:
-        _ensure_pgconn(self)
-
         cdef int i
         cdef types_fast
         cdef Py_ssize_t nparams = 0
@@ -336,12 +403,18 @@ cdef class PGconn:
             for i in range(nparams):
                 atypes[i] = <object>PySequence_Fast_GET_ITEM(types_fast, i)
 
+        cdef libpq.PGconn *pgconn_ptr
         cdef int rv
-        with nogil:
-            rv = libpq.PQsendPrepare(
-                self._pgconn_ptr, name, command, <int>nparams, atypes
-            )
+        with cython.critical_section(self):
+            with nogil:
+                pgconn_ptr = self._pgconn_ptr
+                if pgconn_ptr is not NULL:
+                    rv = libpq.PQsendPrepare(
+                        pgconn_ptr, name, command, <int>nparams, atypes
+                    )
         PyMem_Free(atypes)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if not rv:
             raise e.OperationalError(
                 f"sending query and params failed: {self.get_error_message()}"
@@ -354,8 +427,6 @@ cdef class PGconn:
         param_formats: Sequence[int] | None = None,
         int result_format = PqFormat.TEXT,
     ) -> None:
-        _ensure_pgconn(self)
-
         cdef Py_ssize_t cnparams
         cdef libpq.Oid *ctypes
         cdef char *const *cvalues
@@ -364,12 +435,18 @@ cdef class PGconn:
         cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
             param_values, None, param_formats)
 
+        cdef libpq.PGconn *pgconn_ptr
         cdef int rv
-        with nogil:
-            rv = libpq.PQsendQueryPrepared(
-                self._pgconn_ptr, name, <int>cnparams, <const char *const *>cvalues,
-                clengths, cformats, result_format)
+        with cython.critical_section(self):
+            with nogil:
+                pgconn_ptr = self._pgconn_ptr
+                if pgconn_ptr is not NULL:
+                    rv = libpq.PQsendQueryPrepared(
+                        pgconn_ptr, name, <int>cnparams, <const char *const *>cvalues,
+                        clengths, cformats, result_format)
         _clear_query_params(ctypes, cvalues, clengths, cformats)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if not rv:
             raise e.OperationalError(
                 f"sending prepared query failed: {self.get_error_message()}"
@@ -381,8 +458,6 @@ cdef class PGconn:
         const char *command,
         param_types: Sequence[int] | None = None,
     ) -> PGresult:
-        _ensure_pgconn(self)
-
         cdef int i
         cdef types_fast
         cdef Py_ssize_t nparams = 0
@@ -396,11 +471,17 @@ cdef class PGconn:
             for i in range(nparams):
                 atypes[i] = <object>PySequence_Fast_GET_ITEM(types_fast, i)
 
+        cdef libpq.PGconn *pgconn_ptr
         cdef libpq.PGresult *rv
-        with nogil:
-            rv = libpq.PQprepare(
-                self._pgconn_ptr, name, command, <int>nparams, atypes)
+        with cython.critical_section(self):
+            with nogil:
+                pgconn_ptr = self._pgconn_ptr
+                if pgconn_ptr is not NULL:
+                    rv = libpq.PQprepare(
+                        pgconn_ptr, name, command, <int>nparams, atypes)
         PyMem_Free(atypes)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if rv is NULL:
             raise e.OperationalError(
                 f"preparing query failed: {self.get_error_message()}")
@@ -413,8 +494,6 @@ cdef class PGconn:
         param_formats: Sequence[int] | None = None,
         int result_format = PqFormat.TEXT,
     ) -> PGresult:
-        _ensure_pgconn(self)
-
         cdef Py_ssize_t cnparams
         cdef libpq.Oid *ctypes
         cdef char *const *cvalues
@@ -423,14 +502,19 @@ cdef class PGconn:
         cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
             param_values, None, param_formats)
 
+        cdef libpq.PGconn *pgconn_ptr
         cdef libpq.PGresult *rv
-        with nogil:
-            rv = libpq.PQexecPrepared(
-                self._pgconn_ptr, name, <int>cnparams,
-                <const char *const *>cvalues,
-                clengths, cformats, result_format)
-
+        with cython.critical_section(self):
+            with nogil:
+                pgconn_ptr = self._pgconn_ptr
+                if pgconn_ptr is not NULL:
+                    rv = libpq.PQexecPrepared(
+                        pgconn_ptr, name, <int>cnparams,
+                        <const char *const *>cvalues,
+                        clengths, cformats, result_format)
         _clear_query_params(ctypes, cvalues, clengths, cformats)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if rv is NULL:
             raise e.OperationalError(
                 f"executing prepared query failed: {self.get_error_message()}"
@@ -438,8 +522,8 @@ cdef class PGconn:
         return PGresult._from_ptr(rv)
 
     def describe_prepared(self, const char *name) -> PGresult:
-        _ensure_pgconn(self)
-        cdef libpq.PGresult *rv = libpq.PQdescribePrepared(self._pgconn_ptr, name)
+        cdef libpq.PGresult *rv = <libpq.PGresult *>_call_libpq_with_param(
+            self, <conn_f_with_param>libpq.PQdescribePrepared, name)
         if rv is NULL:
             raise e.OperationalError(
                 f"describe prepared failed: {self.get_error_message()}"
@@ -447,16 +531,16 @@ cdef class PGconn:
         return PGresult._from_ptr(rv)
 
     def send_describe_prepared(self, const char *name) -> None:
-        _ensure_pgconn(self)
-        cdef int rv = libpq.PQsendDescribePrepared(self._pgconn_ptr, name)
+        cdef int rv = _call_libpq_int_with_param(
+            self, <conn_int_f_with_param>libpq.PQsendDescribePrepared, name)
         if not rv:
             raise e.OperationalError(
                 f"sending describe prepared failed: {self.get_error_message()}"
             )
 
     def describe_portal(self, const char *name) -> PGresult:
-        _ensure_pgconn(self)
-        cdef libpq.PGresult *rv = libpq.PQdescribePortal(self._pgconn_ptr, name)
+        cdef libpq.PGresult *rv = <libpq.PGresult *>_call_libpq_with_param(
+            self, <conn_f_with_param>libpq.PQdescribePortal, name)
         if rv is NULL:
             raise e.OperationalError(
                 f"describe prepared failed: {self.get_error_message()}"
@@ -464,8 +548,8 @@ cdef class PGconn:
         return PGresult._from_ptr(rv)
 
     def send_describe_portal(self, const char *name) -> None:
-        _ensure_pgconn(self)
-        cdef int rv = libpq.PQsendDescribePortal(self._pgconn_ptr, name)
+        cdef int rv = _call_libpq_int_with_param(
+            self, <conn_int_f_with_param>libpq.PQsendDescribePortal, name)
         if not rv:
             raise e.OperationalError(
                 f"sending describe prepared failed: {self.get_error_message()}"
@@ -473,8 +557,14 @@ cdef class PGconn:
 
     def close_prepared(self, const char *name) -> PGresult:
         _check_supported("PQclosePrepared", 170000)
-        _ensure_pgconn(self)
-        cdef libpq.PGresult *rv = libpq.PQclosePrepared(self._pgconn_ptr, name)
+        cdef libpq.PGresult *rv
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            pgconn_ptr = self._pgconn_ptr
+            if pgconn_ptr is not NULL:
+                rv = libpq.PQclosePrepared(pgconn_ptr, name)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if rv is NULL:
             raise e.OperationalError(
                 f"close prepared failed: {self.get_error_message()}"
@@ -483,8 +573,14 @@ cdef class PGconn:
 
     def send_close_prepared(self, const char *name) -> None:
         _check_supported("PQsendClosePrepared", 170000)
-        _ensure_pgconn(self)
-        cdef int rv = libpq.PQsendClosePrepared(self._pgconn_ptr, name)
+        cdef int rv
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            pgconn_ptr = self._pgconn_ptr
+            if pgconn_ptr is not NULL:
+                rv = libpq.PQsendClosePrepared(pgconn_ptr, name)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if not rv:
             raise e.OperationalError(
                 f"sending close prepared failed: {self.get_error_message()}"
@@ -492,8 +588,14 @@ cdef class PGconn:
 
     def close_portal(self, const char *name) -> PGresult:
         _check_supported("PQclosePortal", 170000)
-        _ensure_pgconn(self)
-        cdef libpq.PGresult *rv = libpq.PQclosePortal(self._pgconn_ptr, name)
+        cdef libpq.PGresult *rv
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            pgconn_ptr = self._pgconn_ptr
+            if pgconn_ptr is not NULL:
+                rv = libpq.PQclosePortal(pgconn_ptr, name)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if rv is NULL:
             raise e.OperationalError(
                 f"close prepared failed: {self.get_error_message()}"
@@ -502,73 +604,104 @@ cdef class PGconn:
 
     def send_close_portal(self, const char *name) -> None:
         _check_supported("PQsendClosePortal", 170000)
-        _ensure_pgconn(self)
-        cdef int rv = libpq.PQsendClosePortal(self._pgconn_ptr, name)
+        cdef int rv
+        cdef libpq.PGconn *pgconn_ptr
+        with cython.critical_section(self):
+            pgconn_ptr = self._pgconn_ptr
+            if pgconn_ptr is not NULL:
+                rv = libpq.PQsendClosePortal(pgconn_ptr, name)
+        if pgconn_ptr is NULL:
+            raise e.OperationalError("the connection is closed")
         if not rv:
             raise e.OperationalError(
                 f"sending close prepared failed: {self.get_error_message()}"
             )
 
     def get_result(self) -> "PGresult" | None:
-        cdef libpq.PGresult *pgresult = libpq.PQgetResult(self._pgconn_ptr)
+        cdef libpq.PGresult *pgresult
+        with cython.critical_section(self):
+            pgresult = libpq.PQgetResult(self._pgconn_ptr)
         if pgresult is NULL:
             return None
         return PGresult._from_ptr(pgresult)
 
     def consume_input(self) -> None:
-        if 1 != libpq.PQconsumeInput(self._pgconn_ptr):
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQconsumeInput(self._pgconn_ptr)
+        if 1 != rv:
             raise e.OperationalError(
                 f"consuming input failed: {self.get_error_message()}")
 
     def is_busy(self) -> int:
         cdef int rv
-        with nogil:
-            rv = libpq.PQisBusy(self._pgconn_ptr)
+        with cython.critical_section(self):
+            with nogil:
+                rv = libpq.PQisBusy(self._pgconn_ptr)
         return rv
 
     @property
     def nonblocking(self) -> int:
-        return libpq.PQisnonblocking(self._pgconn_ptr)
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQisnonblocking(self._pgconn_ptr)
+        return rv
 
     @nonblocking.setter
     def nonblocking(self, int arg) -> None:
-        if 0 > libpq.PQsetnonblocking(self._pgconn_ptr, arg):
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQsetnonblocking(self._pgconn_ptr, arg)
+        if 0 > rv:
             raise e.OperationalError(
                 f"setting nonblocking failed: {self.get_error_message()}")
 
     cpdef int flush(self) except -1:
-        if self._pgconn_ptr == NULL:
-            raise e.OperationalError("flushing failed: the connection is closed")
-        cdef int rv = libpq.PQflush(self._pgconn_ptr)
+        cdef int rv
+        with cython.critical_section(self):
+            if self._pgconn_ptr is NULL:
+                raise e.OperationalError("flushing failed: the connection is closed")
+            rv = libpq.PQflush(self._pgconn_ptr)
         if rv < 0:
             raise e.OperationalError(f"flushing failed: {self.get_error_message()}")
         return rv
 
     def set_single_row_mode(self) -> None:
-        if not libpq.PQsetSingleRowMode(self._pgconn_ptr):
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQsetSingleRowMode(self._pgconn_ptr)
+        if not rv:
             raise e.OperationalError("setting single row mode failed")
 
     def set_chunked_rows_mode(self, size: int) -> None:
-        if not libpq.PQsetChunkedRowsMode(self._pgconn_ptr, size):
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQsetChunkedRowsMode(self._pgconn_ptr, size)
+        if not rv:
             raise e.OperationalError("setting chunked rows mode failed")
 
     def cancel_conn(self) -> PGcancelConn:
         _check_supported("PQcancelCreate", 170000)
-        cdef libpq.PGcancelConn *ptr = libpq.PQcancelCreate(self._pgconn_ptr)
+        cdef libpq.PGcancelConn *ptr
+        with cython.critical_section(self):
+            ptr = libpq.PQcancelCreate(self._pgconn_ptr)
         if not ptr:
             raise e.OperationalError("couldn't create cancelConn object")
         return PGcancelConn._from_ptr(ptr)
 
     def get_cancel(self) -> PGcancel:
-        cdef libpq.PGcancel *ptr = libpq.PQgetCancel(self._pgconn_ptr)
+        cdef libpq.PGcancel *ptr
+        with cython.critical_section(self):
+            ptr = libpq.PQgetCancel(self._pgconn_ptr)
         if not ptr:
             raise e.OperationalError("couldn't create cancel object")
         return PGcancel._from_ptr(ptr)
 
     cpdef object notifies(self):
         cdef libpq.PGnotify *ptr
-        with nogil:
-            ptr = libpq.PQnotifies(self._pgconn_ptr)
+        with cython.critical_section(self):
+            with nogil:
+                ptr = libpq.PQnotifies(self._pgconn_ptr)
         if ptr:
             ret = PGnotify(ptr.relname, ptr.be_pid, ptr.extra)
             libpq.PQfreemem(ptr)
@@ -582,7 +715,8 @@ cdef class PGconn:
         cdef Py_ssize_t length
 
         _buffer_as_string_and_size(buffer, &cbuffer, &length)
-        rv = libpq.PQputCopyData(self._pgconn_ptr, cbuffer, <int>length)
+        with cython.critical_section(self):
+            rv = libpq.PQputCopyData(self._pgconn_ptr, cbuffer, <int>length)
         if rv < 0:
             raise e.OperationalError(
                 f"sending copy data failed: {self.get_error_message()}")
@@ -593,7 +727,9 @@ cdef class PGconn:
         cdef const char *cerr = NULL
         if error is not None:
             cerr = PyBytes_AsString(error)
-        rv = libpq.PQputCopyEnd(self._pgconn_ptr, cerr)
+
+        with cython.critical_section(self):
+            rv = libpq.PQputCopyEnd(self._pgconn_ptr, cerr)
         if rv < 0:
             raise e.OperationalError(
                 f"sending copy end failed: {self.get_error_message()}")
@@ -602,7 +738,9 @@ cdef class PGconn:
     def get_copy_data(self, int async_) -> tuple[int, memoryview]:
         cdef char *buffer_ptr = NULL
         cdef int nbytes
-        nbytes = libpq.PQgetCopyData(self._pgconn_ptr, &buffer_ptr, async_)
+
+        with cython.critical_section(self):
+            nbytes = libpq.PQgetCopyData(self._pgconn_ptr, &buffer_ptr, async_)
         if nbytes == -2:
             raise e.OperationalError(
                 f"receiving copy data failed: {self.get_error_message()}")
@@ -617,14 +755,17 @@ cdef class PGconn:
         if sys.platform != "linux":
             raise e.NotSupportedError("currently only supported on Linux")
         stream = fdopen(fileno, b"w")
-        libpq.PQtrace(self._pgconn_ptr, stream)
+        with cython.critical_section(self):
+            libpq.PQtrace(self._pgconn_ptr, stream)
 
     def set_trace_flags(self, flags: Trace) -> None:
         _check_supported("PQsetTraceFlags", 140000)
-        libpq.PQsetTraceFlags(self._pgconn_ptr, flags)
+        with cython.critical_section(self):
+            libpq.PQsetTraceFlags(self._pgconn_ptr, flags)
 
     def untrace(self) -> None:
-        libpq.PQuntrace(self._pgconn_ptr)
+        with cython.critical_section(self):
+            libpq.PQuntrace(self._pgconn_ptr)
 
     def encrypt_password(
         self, const char *passwd, const char *user, algorithm = None
@@ -635,7 +776,8 @@ cdef class PGconn:
         cdef const char *calgo = NULL
         if algorithm:
             calgo = algorithm
-        out = libpq.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, calgo)
+        with cython.critical_section(self):
+            out = libpq.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, calgo)
         if not out:
             raise e.OperationalError(
                 f"password encryption failed: {self.get_error_message()}"
@@ -651,15 +793,18 @@ cdef class PGconn:
         _check_supported("PQchangePassword", 170000)
 
         cdef libpq.PGresult *res
-        res = libpq.PQchangePassword(self._pgconn_ptr, user, passwd)
+        with cython.critical_section(self):
+            res = libpq.PQchangePassword(self._pgconn_ptr, user, passwd)
         if libpq.PQresultStatus(res) != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 f"password encryption failed: {self.get_error_message()}"
             )
 
     def make_empty_result(self, int exec_status) -> PGresult:
-        cdef libpq.PGresult *rv = libpq.PQmakeEmptyPGresult(
-            self._pgconn_ptr, <libpq.ExecStatusType>exec_status)
+        cdef libpq.PGresult *rv
+        with cython.critical_section(self):
+            rv = libpq.PQmakeEmptyPGresult(
+                self._pgconn_ptr, <libpq.ExecStatusType>exec_status)
         if not rv:
             raise MemoryError("couldn't allocate empty PGresult")
         return PGresult._from_ptr(rv)
@@ -672,8 +817,10 @@ cdef class PGconn:
         """
         if libpq.PG_VERSION_NUM < 140000:
             return libpq.PQ_PIPELINE_OFF
-        cdef int status = libpq.PQpipelineStatus(self._pgconn_ptr)
-        return status
+        cdef libpq.PGpipelineStatus rv
+        with cython.critical_section(self):
+            rv = libpq.PQpipelineStatus(self._pgconn_ptr)
+        return rv
 
     def enter_pipeline_mode(self) -> None:
         """Enter pipeline mode.
@@ -682,7 +829,10 @@ cdef class PGconn:
             mode.
         """
         _check_supported("PQenterPipelineMode", 140000)
-        if libpq.PQenterPipelineMode(self._pgconn_ptr) != 1:
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQenterPipelineMode(self._pgconn_ptr)
+        if rv != 1:
             raise e.OperationalError("failed to enter pipeline mode")
 
     def exit_pipeline_mode(self) -> None:
@@ -692,7 +842,10 @@ cdef class PGconn:
             mode.
         """
         _check_supported("PQexitPipelineMode", 140000)
-        if libpq.PQexitPipelineMode(self._pgconn_ptr) != 1:
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQexitPipelineMode(self._pgconn_ptr)
+        if rv != 1:
             raise e.OperationalError(self.get_error_message())
 
     def pipeline_sync(self) -> None:
@@ -702,7 +855,9 @@ cdef class PGconn:
             or if sync failed.
         """
         _check_supported("PQpipelineSync", 140000)
-        rv = libpq.PQpipelineSync(self._pgconn_ptr)
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQpipelineSync(self._pgconn_ptr)
         if rv == 0:
             raise e.OperationalError("connection not in pipeline mode")
         if rv != 1:
@@ -714,39 +869,70 @@ cdef class PGconn:
         :raises ~e.OperationalError: if the flush request failed.
         """
         _check_supported("PQsendFlushRequest ", 140000)
-        cdef int rv = libpq.PQsendFlushRequest(self._pgconn_ptr)
+        cdef int rv
+        with cython.critical_section(self):
+            rv = libpq.PQsendFlushRequest(self._pgconn_ptr)
         if rv == 0:
             raise e.OperationalError(
                 f"flush request failed: {self.get_error_message()}")
 
 
-cdef int _ensure_pgconn(PGconn pgconn) except 0:
-    if pgconn._pgconn_ptr is not NULL:
-        return 1
-
-    raise e.OperationalError("the connection is closed")
-
-
-cdef char *_call_bytes(PGconn pgconn, conn_bytes_f func) except NULL:
-    """
-    Call one of the pgconn libpq functions returning a bytes pointer.
-    """
-    if not _ensure_pgconn(pgconn):
-        return NULL
-    cdef char *rv = func(pgconn._pgconn_ptr)
+cdef void *_call_libpq_with_param(
+    PGconn self,
+    conn_f_with_param func,
+    const char *param
+):
+    cdef void *rv = NULL
+    cdef libpq.PGconn *pgconn_ptr
+    with cython.critical_section(self):
+        pgconn_ptr = self._pgconn_ptr
+        if pgconn_ptr is not NULL:
+            rv = func(pgconn_ptr, param)
+    if pgconn_ptr is NULL:
+        raise e.OperationalError("the connection is closed")
+    return rv
+
+
+cdef int _call_libpq_int(PGconn self, conn_int_f func):
+    cdef int rv
+    cdef libpq.PGconn *pgconn_ptr
+    with cython.critical_section(self):
+        pgconn_ptr = self._pgconn_ptr
+        if pgconn_ptr is not NULL:
+            rv = func(pgconn_ptr)
+    if pgconn_ptr is NULL:
+        raise e.OperationalError("the connection is closed")
+    return rv
+
+
+cdef int _call_libpq_int_with_param(
+    PGconn self,
+    conn_int_f_with_param func,
+    const char *param
+):
+    cdef int rv
+    cdef libpq.PGconn *pgconn_ptr
+    with cython.critical_section(self):
+        pgconn_ptr = self._pgconn_ptr
+        if pgconn_ptr is not NULL:
+            rv = func(pgconn_ptr, param)
+    if pgconn_ptr is NULL:
+        raise e.OperationalError("the connection is closed")
+    return rv
+
+
+cdef bytes _call_libpq_bytes(PGconn self, conn_bytes_f func):
+    cdef char *rv
+    cdef libpq.PGconn *pgconn_ptr
+    with cython.critical_section(self):
+        pgconn_ptr = self._pgconn_ptr
+        if pgconn_ptr is not NULL:
+            rv = func(pgconn_ptr)
+    if pgconn_ptr is NULL:
+        raise e.OperationalError("the connection is closed")
     if rv is not NULL:
         return rv
-    else:
-        return b""
-
-
-cdef int _call_int(PGconn pgconn, conn_int_f func) except -2:
-    """
-    Call one of the pgconn libpq functions returning an int.
-    """
-    if not _ensure_pgconn(pgconn):
-        return -2
-    return func(pgconn._pgconn_ptr)
+    return b""
 
 
 cdef void notice_receiver(void *arg, const libpq.PGresult *res_ptr) noexcept with gil:
index cfe80906f3ca5f77ec1e2c54c7c5a17689293bea..b4b5122b17d310446670bb5d0798c34971861b9f 100644 (file)
@@ -1,3 +1,4 @@
+import time
 import threading
 from concurrent.futures import ThreadPoolExecutor
 
@@ -182,3 +183,54 @@ def test_same_cursor_from_multiple_threads_no_crash(conn):
             future.result()
 
     cur.close()
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("testctx")
+def test_connection_finish_while_executing(conn):
+    with conn.cursor() as cur:
+        cur.execute("insert into testctx values (1)")
+    conn.commit()
+
+    def closer():
+        time.sleep(1)
+        conn.close()
+
+    def reader():
+        cur = conn.cursor()
+        try:
+            while True:
+                cur.execute("select id from testctx")
+                assert [row[0] for row in cur.fetchall()] == [1]
+        except Exception:
+            pass
+
+    with ThreadPoolExecutor(max_workers=2) as tpe:
+        future2 = tpe.submit(reader)
+        future1 = tpe.submit(closer)
+        future1.result()
+        future2.result()
+
+
+@pytest.mark.slow
+def test_connection_close_race_condition(dsn):
+    conn = psycopg.connect(dsn, autocommit=True)
+    barrier = threading.Barrier(parties=2)
+
+    def reader():
+        barrier.wait()
+        messages = [conn.pgconn.error_message for _ in range(100)]
+        return messages
+
+    def closer():
+        barrier.wait()
+        conn.pgconn.finish()
+
+    with ThreadPoolExecutor(max_workers=2) as tpe:
+        reader_future = tpe.submit(reader)
+        closer_future = tpe.submit(closer)
+        error_messages = reader_future.result()
+        closer_future.result()
+
+    for error_message in error_messages:
+        assert error_message in (b"", b"connection pointer is NULL\n")