]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-143756: Avoid borrowed reference in SSL code (gh-143816)
authorSam Gross <colesbury@gmail.com>
Thu, 22 Jan 2026 19:02:48 +0000 (14:02 -0500)
committerGitHub <noreply@github.com>
Thu, 22 Jan 2026 19:02:48 +0000 (14:02 -0500)
GET_SOCKET() returned a borrowed reference, which was potentially
unsafe. Also, refactor out some common code.

Modules/_ssl.c

index e240b889d86a2d5203c7bd9c7140ac47e0bbc910..22865bdfc3f727fee44b4761bff7cfbacc41b005 100644 (file)
@@ -423,26 +423,6 @@ typedef enum {
 #define ERRSTR1(x,y,z) (x ":" y ": " z)
 #define ERRSTR(x) ERRSTR1("_ssl.c", Py_STRINGIFY(__LINE__), x)
 
-// Get the socket from a PySSLSocket, if it has one.
-// Return a borrowed reference.
-static inline PySocketSockObject* GET_SOCKET(PySSLSocket *obj) {
-    if (obj->Socket) {
-        PyObject *sock;
-        if (PyWeakref_GetRef(obj->Socket, &sock)) {
-            // GET_SOCKET() returns a borrowed reference
-            Py_DECREF(sock);
-        }
-        else {
-            // dead weak reference
-            sock = Py_None;
-        }
-        return (PySocketSockObject *)sock;  // borrowed reference
-    }
-    else {
-        return NULL;
-    }
-}
-
 /* If sock is NULL, use a timeout of 0 second */
 #define GET_SOCKET_TIMEOUT(sock) \
     ((sock != NULL) ? (sock)->sock_timeout : 0)
@@ -794,6 +774,35 @@ _ssl_deprecated(const char* msg, int stacklevel) {
 #define PY_SSL_DEPRECATED(name, stacklevel, ret) \
     if (_ssl_deprecated((name), (stacklevel)) == -1) return (ret)
 
+// Get the socket from a PySSLSocket, if it has one.
+// Stores a strong reference in out_sock.
+static int
+get_socket(PySSLSocket *obj, PySocketSockObject **out_sock,
+           const char *filename, int lineno)
+{
+    if (!obj->Socket) {
+        *out_sock = NULL;
+        return 0;
+    }
+    PySocketSockObject *sock;
+    int res = PyWeakref_GetRef(obj->Socket, (PyObject **)&sock);
+    if (res == 0 || sock->sock_fd == INVALID_SOCKET) {
+        _setSSLError(get_state_sock(obj),
+                     "Underlying socket connection gone",
+                     PY_SSL_ERROR_NO_SOCKET, filename, lineno);
+        *out_sock = NULL;
+        return -1;
+    }
+    if (sock != NULL) {
+        /* just in case the blocking state of the socket has been changed */
+        int nonblocking = (sock->sock_timeout >= 0);
+        BIO_set_nbio(SSL_get_rbio(obj->ssl), nonblocking);
+        BIO_set_nbio(SSL_get_wbio(obj->ssl), nonblocking);
+    }
+    *out_sock = sock;
+    return res;
+}
+
 /*
  * SSL objects
  */
@@ -1021,24 +1030,13 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
     int ret;
     _PySSLError err;
     PyObject *exc = NULL;
-    int sockstate, nonblocking;
-    PySocketSockObject *sock = GET_SOCKET(self);
+    int sockstate;
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
-    if (sock) {
-        if (((PyObject*)sock) == Py_None) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
-
-        /* just in case the blocking state of the socket has been changed */
-        nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     timeout = GET_SOCKET_TIMEOUT(sock);
@@ -2610,22 +2608,12 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset,
     int sockstate;
     _PySSLError err;
     PyObject *exc = NULL;
-    PySocketSockObject *sock = GET_SOCKET(self);
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
-    if (sock != NULL) {
-        if ((PyObject *)sock == Py_None) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
-        /* just in case the blocking state of the socket has been changed */
-        int nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     timeout = GET_SOCKET_TIMEOUT(sock);
@@ -2747,26 +2735,12 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b)
     int sockstate;
     _PySSLError err;
     PyObject *exc = NULL;
-    int nonblocking;
-    PySocketSockObject *sock = GET_SOCKET(self);
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
-    if (sock != NULL) {
-        if (((PyObject*)sock) == Py_None) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
-    }
-
-    if (sock != NULL) {
-        /* just in case the blocking state of the socket has been changed */
-        nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     timeout = GET_SOCKET_TIMEOUT(sock);
@@ -2896,8 +2870,6 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
     int sockstate;
     _PySSLError err;
     PyObject *exc = NULL;
-    int nonblocking;
-    PySocketSockObject *sock = GET_SOCKET(self);
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
@@ -2906,14 +2878,9 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
         return NULL;
     }
 
-    if (sock != NULL) {
-        if (((PyObject*)sock) == Py_None) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     if (!group_right_1) {
@@ -2944,13 +2911,6 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
         }
     }
 
-    if (sock != NULL) {
-        /* just in case the blocking state of the socket has been changed */
-        nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
-    }
-
     timeout = GET_SOCKET_TIMEOUT(sock);
     has_timeout = (timeout > 0);
     if (has_timeout)
@@ -3041,26 +3001,14 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
 {
     _PySSLError err;
     PyObject *exc = NULL;
-    int sockstate, nonblocking, ret;
+    int sockstate, ret;
     int zeros = 0;
-    PySocketSockObject *sock = GET_SOCKET(self);
     PyTime_t timeout, deadline = 0;
     int has_timeout;
 
-    if (sock != NULL) {
-        /* Guard against closed socket */
-        if ((((PyObject*)sock) == Py_None) || (sock->sock_fd == INVALID_SOCKET)) {
-            _setSSLError(get_state_sock(self),
-                         "Underlying socket connection gone",
-                         PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
-            return NULL;
-        }
-        Py_INCREF(sock);
-
-        /* Just in case the blocking state of the socket has been changed */
-        nonblocking = (sock->sock_timeout >= 0);
-        BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
-        BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+    PySocketSockObject *sock = NULL;
+    if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
+        return NULL;
     }
 
     timeout = GET_SOCKET_TIMEOUT(sock);