]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
SF patch 676472 by Geoff Talvola, reviewed by Ben Laurie.
authorGuido van Rossum <guido@python.org>
Fri, 31 Jan 2003 18:13:18 +0000 (18:13 +0000)
committerGuido van Rossum <guido@python.org>
Fri, 31 Jan 2003 18:13:18 +0000 (18:13 +0000)
Geoff writes:
  This is yet another patch to _ssl.c that sets the
  underlying BIO to non-blocking if the socket being
  wrapped is non-blocking. It also correctly loops when
  SSL_connect, SSL_write, or SSL_read indicates that it
  needs to read or write more bytes.

  This seems to fix bug #673797 which was not fixed by my
  previous patch.

Modules/_ssl.c

index 0a42fe7f162f3b6a830023655810a4c897a6d183..cfcb8a5763b495c8736f4bd7bf2fe3a646915d1b 100644 (file)
@@ -168,6 +168,8 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)
        PySSLObject *self;
        char *errstr = NULL;
        int ret;
+       int err;
+       int timedout;
 
        self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */
        if (self == NULL){
@@ -220,14 +222,38 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)
        self->ssl = SSL_new(self->ctx); /* New ssl struct */
        Py_END_ALLOW_THREADS
        SSL_set_fd(self->ssl, Sock->sock_fd);   /* Set the socket for SSL */
+
+       /* If the socket is is non-blocking mode or timeout mode, set the BIO
+        * to non-blocking mode (blocking is the default)
+        */
+       if (Sock->sock_timeout >= 0.0) {
+               /* Set both the read and write BIO's to non-blocking mode */
+               BIO_set_nbio(SSL_get_rbio(self->ssl), 1);
+               BIO_set_nbio(SSL_get_wbio(self->ssl), 1);
+       }
+
        Py_BEGIN_ALLOW_THREADS
        SSL_set_connect_state(self->ssl);
-
+       Py_END_ALLOW_THREADS
 
        /* Actually negotiate SSL connection */
        /* XXX If SSL_connect() returns 0, it's also a failure. */
-       ret = SSL_connect(self->ssl);
-       Py_END_ALLOW_THREADS
+       timedout = 0;
+       do {
+               Py_BEGIN_ALLOW_THREADS
+               ret = SSL_connect(self->ssl);
+               err = SSL_get_error(self->ssl, ret);
+               Py_END_ALLOW_THREADS
+               if (err == SSL_ERROR_WANT_READ) {
+                       timedout = wait_for_timeout(Sock, 0);
+               } else if (err == SSL_ERROR_WANT_WRITE) {
+                       timedout = wait_for_timeout(Sock, 1);
+               }
+               if (timedout) {
+                       PyErr_SetString(PySSLErrorObject, "The connect operation timed out");
+                       return NULL;
+               }
+       } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
        if (ret <= 0) {
                PySSL_SetError(self, ret);
                goto fail;
@@ -328,10 +354,12 @@ wait_for_timeout(PySocketSockObject *s, int writing)
        FD_SET(s->sock_fd, &fds);
 
        /* See if the socket is ready */
+       Py_BEGIN_ALLOW_THREADS
        if (writing)
                rc = select(s->sock_fd+1, NULL, &fds, NULL, &tv);
        else
                rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv);
+       Py_END_ALLOW_THREADS
 
        /* Return 1 on timeout, 0 otherwise */
        return rc == 0;
@@ -342,20 +370,32 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
        char *data;
        int len;
        int timedout;
+       int err;
 
        if (!PyArg_ParseTuple(args, "s#:write", &data, &len))
                return NULL;
 
-       Py_BEGIN_ALLOW_THREADS
        timedout = wait_for_timeout(self->Socket, 1);
-       Py_END_ALLOW_THREADS
        if (timedout) {
                PyErr_SetString(PySSLErrorObject, "The write operation timed out");
                return NULL;
        }
-       Py_BEGIN_ALLOW_THREADS
-       len = SSL_write(self->ssl, data, len);
-       Py_END_ALLOW_THREADS
+       do {
+               err = 0;
+               Py_BEGIN_ALLOW_THREADS
+               len = SSL_write(self->ssl, data, len);
+               err = SSL_get_error(self->ssl, len);
+               Py_END_ALLOW_THREADS
+               if (err == SSL_ERROR_WANT_READ) {
+                       timedout = wait_for_timeout(self->Socket, 0);
+               } else if (err == SSL_ERROR_WANT_WRITE) {
+                       timedout = wait_for_timeout(self->Socket, 1);
+               }
+               if (timedout) {
+                       PyErr_SetString(PySSLErrorObject, "The write operation timed out");
+                       return NULL;
+               }
+       } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
        if (len > 0)
                return PyInt_FromLong(len);
        else
@@ -374,6 +414,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
        int count = 0;
        int len = 1024;
        int timedout;
+       int err;
 
        if (!PyArg_ParseTuple(args, "|i:read", &len))
                return NULL;
@@ -381,16 +422,27 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
        if (!(buf = PyString_FromStringAndSize((char *) 0, len)))
                return NULL;
 
-       Py_BEGIN_ALLOW_THREADS
        timedout = wait_for_timeout(self->Socket, 0);
-       Py_END_ALLOW_THREADS
        if (timedout) {
                PyErr_SetString(PySSLErrorObject, "The read operation timed out");
                return NULL;
        }
-       Py_BEGIN_ALLOW_THREADS
-       count = SSL_read(self->ssl, PyString_AsString(buf), len);
-       Py_END_ALLOW_THREADS
+       do {
+               err = 0;
+               Py_BEGIN_ALLOW_THREADS
+               count = SSL_read(self->ssl, PyString_AsString(buf), len);
+               err = SSL_get_error(self->ssl, count);
+               Py_END_ALLOW_THREADS
+               if (err == SSL_ERROR_WANT_READ) {
+                       timedout = wait_for_timeout(self->Socket, 0);
+               } else if (err == SSL_ERROR_WANT_WRITE) {
+                       timedout = wait_for_timeout(self->Socket, 1);
+               }
+               if (timedout) {
+                       PyErr_SetString(PySSLErrorObject, "The read operation timed out");
+                       return NULL;
+               }
+       } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
        if (count <= 0) {
                Py_DECREF(buf);
                return PySSL_SetError(self, count);