]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Add locking to QUIC front-end
authorHugo Landau <hlandau@openssl.org>
Tue, 21 Feb 2023 10:18:58 +0000 (10:18 +0000)
committerHugo Landau <hlandau@openssl.org>
Thu, 30 Mar 2023 10:14:07 +0000 (11:14 +0100)
Reviewed-by: Tomas Mraz <tomas@openssl.org>
Reviewed-by: Matt Caswell <matt@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/20348)

include/internal/quic_channel.h
ssl/quic/quic_impl.c

index cd51fe30a449672b12d63071c5929b3995978e48..a279bfb5dd4eb96c2f069819ffbab0e4cf862a86 100644 (file)
  * mutex which then serves as the channel mutex; see QUIC_CHANNEL_ARGS.
  */
 
+/*
+ * The function does not acquire the channel mutex and assumes it is already
+ * held by the calling thread.
+ *
+ * Any function tagged with this has the following precondition:
+ *
+ *   Precondition: must hold channel mutex (unchecked)
+ */
 #  define QUIC_NEEDS_LOCK
+
+/*
+ * The function acquires the channel mutex and releases it before returning in
+ * all circumstances.
+ *
+ * Any function tagged with this has the following precondition and
+ * postcondition:
+ *
+ *   Precondition: must not hold channel mutex (unchecked)
+ *   Postcondition: channel mutex is not held (by calling thread)
+ *
+ */
 #  define QUIC_TAKES_LOCK
+
 #  define QUIC_TODO_LOCK
 
 #  define QUIC_CHANNEL_STATE_IDLE                        0
index c17b5354a2b0da60506764993853150992208068..2773e8784e914b0be8d3913da0a7bdc38d5e0f29 100644 (file)
@@ -175,7 +175,7 @@ err:
 }
 
 /* SSL_free */
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 void ossl_quic_free(SSL *s)
 {
     QUIC_CONNECTION *qc = QUIC_CONNECTION_FROM_SSL(s);
@@ -184,6 +184,7 @@ void ossl_quic_free(SSL *s)
     if (!expect_quic_conn(qc))
         return;
 
+    quic_lock(qc); /* best effort */
     ossl_quic_channel_free(qc->ch);
 
     BIO_free(qc->net_rbio);
@@ -431,13 +432,19 @@ static int blocking_mode(const QUIC_CONNECTION *qc)
 }
 
 /* SSL_tick; ticks the reactor. */
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 int ossl_quic_tick(QUIC_CONNECTION *qc)
 {
-    if (qc->ch == NULL)
+    if (!quic_lock(qc))
+        return 0;
+
+    if (qc->ch == NULL) {
+        quic_unlock(qc);
         return 1;
+    }
 
     ossl_quic_reactor_tick(ossl_quic_channel_get_reactor(qc->ch));
+    quic_unlock(qc);
     return 1;
 }
 
@@ -447,11 +454,14 @@ int ossl_quic_tick(QUIC_CONNECTION *qc)
  * the object should be ticked immediately and tv->tv_sec is set to -1 if no
  * timeout is currently active.
  */
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 int ossl_quic_get_tick_timeout(QUIC_CONNECTION *qc, struct timeval *tv)
 {
     OSSL_TIME deadline = ossl_time_infinite();
 
+    if (!quic_lock(qc))
+        return 0;
+
     if (qc->ch != NULL)
         deadline
             = ossl_quic_reactor_get_tick_deadline(ossl_quic_channel_get_reactor(qc->ch));
@@ -459,10 +469,12 @@ int ossl_quic_get_tick_timeout(QUIC_CONNECTION *qc, struct timeval *tv)
     if (ossl_time_is_infinite(deadline)) {
         tv->tv_sec  = -1;
         tv->tv_usec = 0;
+        quic_unlock(qc);
         return 1;
     }
 
     *tv = ossl_time_to_timeval(ossl_time_subtract(deadline, ossl_time_now()));
+    quic_unlock(qc);
     return 1;
 }
 
@@ -485,23 +497,37 @@ int ossl_quic_get_wpoll_descriptor(QUIC_CONNECTION *qc, BIO_POLL_DESCRIPTOR *des
 }
 
 /* SSL_net_read_desired */
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 int ossl_quic_get_net_read_desired(QUIC_CONNECTION *qc)
 {
+    int ret;
+
+    if (!quic_lock(qc))
+        return 0;
+
     if (qc->ch == NULL)
         return 0;
 
-    return ossl_quic_reactor_net_read_desired(ossl_quic_channel_get_reactor(qc->ch));
+    ret = ossl_quic_reactor_net_read_desired(ossl_quic_channel_get_reactor(qc->ch));
+    quic_unlock(qc);
+    return ret;
 }
 
 /* SSL_net_write_desired */
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 int ossl_quic_get_net_write_desired(QUIC_CONNECTION *qc)
 {
+    int ret;
+
+    if (!quic_lock(qc))
+        return 0;
+
     if (qc->ch == NULL)
         return 0;
 
-    return ossl_quic_reactor_net_write_desired(ossl_quic_channel_get_reactor(qc->ch));
+    ret = ossl_quic_reactor_net_write_desired(ossl_quic_channel_get_reactor(qc->ch));
+    quic_unlock(qc);
+    return ret;
 }
 
 /*
@@ -526,28 +552,39 @@ static int quic_shutdown_wait(void *arg)
     return qc->ch == NULL || ossl_quic_channel_is_terminated(qc->ch);
 }
 
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 int ossl_quic_conn_shutdown(QUIC_CONNECTION *qc, uint64_t flags,
                             const SSL_SHUTDOWN_EX_ARGS *args,
                             size_t args_len)
 {
-    if (!ensure_channel(qc))
+    int ret;
+
+    if (!quic_lock(qc))
+        return -1;
+
+    if (!ensure_channel(qc)) {
+        quic_unlock(qc);
         return -1;
+    }
 
     ossl_quic_channel_local_close(qc->ch,
                                   args != NULL ? args->quic_error_code : 0);
 
     /* TODO(QUIC): !SSL_SHUTDOWN_FLAG_NO_STREAM_FLUSH */
 
-    if (ossl_quic_channel_is_terminated(qc->ch))
+    if (ossl_quic_channel_is_terminated(qc->ch)) {
+        quic_unlock(qc);
         return 1;
+    }
 
     if (blocking_mode(qc) && (flags & SSL_SHUTDOWN_FLAG_RAPID) == 0)
         block_until_pred(qc, quic_shutdown_wait, qc, 0);
     else
         ossl_quic_reactor_tick(ossl_quic_channel_get_reactor(qc->ch));
 
-    return ossl_quic_channel_is_terminated(qc->ch);
+    ret = ossl_quic_channel_is_terminated(qc->ch);
+    quic_unlock(qc);
+    return ret;
 }
 
 /* SSL_ctrl */
@@ -674,34 +711,44 @@ static int ensure_channel_and_start(QUIC_CONNECTION *qc)
     return 1;
 }
 
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 int ossl_quic_do_handshake(QUIC_CONNECTION *qc)
 {
     int ret;
 
-    if (qc->ch != NULL && ossl_quic_channel_is_handshake_complete(qc->ch))
+    if (!quic_lock(qc))
+        return -1;
+
+    if (qc->ch != NULL && ossl_quic_channel_is_handshake_complete(qc->ch)) {
         /* Handshake already completed. */
-        return 1;
+        ret = 1;
+        goto out;
+    }
 
-    if (qc->ch != NULL && ossl_quic_channel_is_term_any(qc->ch))
-        return QUIC_RAISE_NON_NORMAL_ERROR(qc, SSL_R_PROTOCOL_IS_SHUTDOWN, NULL);
+    if (qc->ch != NULL && ossl_quic_channel_is_term_any(qc->ch)) {
+        ret = QUIC_RAISE_NON_NORMAL_ERROR(qc, SSL_R_PROTOCOL_IS_SHUTDOWN, NULL);
+        goto out;
+    }
 
     if (BIO_ADDR_family(&qc->init_peer_addr) == AF_UNSPEC) {
         /* Peer address must have been set. */
         QUIC_RAISE_NON_NORMAL_ERROR(qc, SSL_R_REMOTE_PEER_ADDRESS_NOT_SET, NULL);
-        return -1; /* Non-protocol error */
+        ret = -1; /* Non-protocol error */
+        goto out;
     }
 
     if (qc->as_server) {
         /* TODO(QUIC): Server mode not currently supported */
         QUIC_RAISE_NON_NORMAL_ERROR(qc, ERR_R_PASSED_INVALID_ARGUMENT, NULL);
-        return -1; /* Non-protocol error */
+        ret = -1;
+        goto out; /* Non-protocol error */
     }
 
     if (qc->net_rbio == NULL || qc->net_wbio == NULL) {
         /* Need read and write BIOs. */
         QUIC_RAISE_NON_NORMAL_ERROR(qc, SSL_R_BIO_NOT_SET, NULL);
-        return -1; /* Non-protocol error */
+        ret = -1;
+        goto out; /* Non-protocol error */
     }
 
     /*
@@ -710,12 +757,15 @@ int ossl_quic_do_handshake(QUIC_CONNECTION *qc)
      */
     if (!ensure_channel_and_start(qc)) {
         QUIC_RAISE_NON_NORMAL_ERROR(qc, ERR_R_INTERNAL_ERROR, NULL);
-        return -1; /* Non-protocol error */
+        ret = -1;
+        goto out; /* Non-protocol error */
     }
 
-    if (ossl_quic_channel_is_handshake_complete(qc->ch))
+    if (ossl_quic_channel_is_handshake_complete(qc->ch)) {
         /* The handshake is now done. */
-        return 1;
+        ret = 1;
+        goto out;
+    }
 
     if (blocking_mode(qc)) {
         /* In blocking mode, wait for the handshake to complete. */
@@ -726,26 +776,36 @@ int ossl_quic_do_handshake(QUIC_CONNECTION *qc)
         ret = block_until_pred(qc, quic_handshake_wait, &args, 0);
         if (!ossl_quic_channel_is_active(qc->ch)) {
             QUIC_RAISE_NON_NORMAL_ERROR(qc, SSL_R_PROTOCOL_IS_SHUTDOWN, NULL);
-            return 0; /* Shutdown before completion */
+            ret = 0;
+            goto out; /* Shutdown before completion */
         } else if (ret <= 0) {
             QUIC_RAISE_NON_NORMAL_ERROR(qc, ERR_R_INTERNAL_ERROR, NULL);
-            return -1; /* Non-protocol error */
+            ret = -1;
+            goto out; /* Non-protocol error */
         }
 
         assert(ossl_quic_channel_is_handshake_complete(qc->ch));
-        return 1;
+        ret = 1;
+        goto out;
     } else {
         /* Try to advance the reactor. */
         ossl_quic_reactor_tick(ossl_quic_channel_get_reactor(qc->ch));
 
-        if (ossl_quic_channel_is_handshake_complete(qc->ch))
+        if (ossl_quic_channel_is_handshake_complete(qc->ch)) {
             /* The handshake is now done. */
-            return 1;
+            ret = 1;
+            goto out;
+        }
 
         /* Otherwise, indicate that the handshake isn't done yet. */
         QUIC_RAISE_NORMAL_ERROR(qc, SSL_ERROR_WANT_READ);
-        return -1; /* Non-protocol error */
+        ret = -1;
+        goto out; /* Non-protocol error */
     }
+
+out:
+    quic_unlock(qc);
+    return ret;
 }
 
 /* SSL_connect */
@@ -1044,9 +1104,10 @@ static int quic_write_nonblocking_epw(QUIC_CONNECTION *qc, const void *buf, size
     return 1;
 }
 
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 int ossl_quic_write(SSL *s, const void *buf, size_t len, size_t *written)
 {
+    int ret;
     QUIC_CONNECTION *qc = QUIC_CONNECTION_FROM_SSL(s);
     int partial_write = ((qc->ssl_mode & SSL_MODE_ENABLE_PARTIAL_WRITE) != 0);
 
@@ -1055,25 +1116,38 @@ int ossl_quic_write(SSL *s, const void *buf, size_t len, size_t *written)
     if (!expect_quic_conn(qc))
         return 0;
 
-    if (qc->ch != NULL && ossl_quic_channel_is_term_any(qc->ch))
-        return QUIC_RAISE_NON_NORMAL_ERROR(qc, SSL_R_PROTOCOL_IS_SHUTDOWN, NULL);
+    if (!quic_lock(qc))
+        return 0;
+
+    if (qc->ch != NULL && ossl_quic_channel_is_term_any(qc->ch)) {
+        ret = QUIC_RAISE_NON_NORMAL_ERROR(qc, SSL_R_PROTOCOL_IS_SHUTDOWN, NULL);
+        goto out;
+    }
 
     /*
      * If we haven't finished the handshake, try to advance it.
      * We don't accept writes until the handshake is completed.
      */
-    if (ossl_quic_do_handshake(qc) < 1)
-        return 0;
+    if (ossl_quic_do_handshake(qc) < 1) {
+        ret = 0;
+        goto out;
+    }
 
-    if (qc->stream0 == NULL || qc->stream0->sstream == NULL)
-        return QUIC_RAISE_NON_NORMAL_ERROR(qc, ERR_R_INTERNAL_ERROR, NULL);
+    if (qc->stream0 == NULL || qc->stream0->sstream == NULL) {
+        ret = QUIC_RAISE_NON_NORMAL_ERROR(qc, ERR_R_INTERNAL_ERROR, NULL);
+        goto out;
+    }
 
     if (blocking_mode(qc))
-        return quic_write_blocking(qc, buf, len, written);
+        ret = quic_write_blocking(qc, buf, len, written);
     else if (partial_write)
-        return quic_write_nonblocking_epw(qc, buf, len, written);
+        ret = quic_write_nonblocking_epw(qc, buf, len, written);
     else
-        return quic_write_nonblocking_aon(qc, buf, len, written);
+        ret = quic_write_nonblocking_aon(qc, buf, len, written);
+
+out:
+    quic_unlock(qc);
+    return ret;
 }
 
 /*
@@ -1167,10 +1241,10 @@ static int quic_read_again(void *arg)
     return 0; /* did not read anything, keep trying */
 }
 
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 static int quic_read(SSL *s, void *buf, size_t len, size_t *bytes_read, int peek)
 {
-    int res;
+    int ret, res;
     QUIC_CONNECTION *qc = QUIC_CONNECTION_FROM_SSL(s);
     struct quic_read_again_args args;
 
@@ -1179,18 +1253,29 @@ static int quic_read(SSL *s, void *buf, size_t len, size_t *bytes_read, int peek
     if (!expect_quic_conn(qc))
         return 0;
 
-    if (qc->ch != NULL && ossl_quic_channel_is_term_any(qc->ch))
-        return QUIC_RAISE_NON_NORMAL_ERROR(qc, SSL_R_PROTOCOL_IS_SHUTDOWN, NULL);
+    if (!quic_lock(qc))
+        return 0;
+
+    if (qc->ch != NULL && ossl_quic_channel_is_term_any(qc->ch)) {
+        ret = QUIC_RAISE_NON_NORMAL_ERROR(qc, SSL_R_PROTOCOL_IS_SHUTDOWN, NULL);
+        goto out;
+    }
 
     /* If we haven't finished the handshake, try to advance it. */
-    if (ossl_quic_do_handshake(qc) < 1)
-        return 0; /* ossl_quic_do_handshake raised error here */
+    if (ossl_quic_do_handshake(qc) < 1) {
+        ret = 0; /* ossl_quic_do_handshake raised error here */
+        goto out;
+    }
 
-    if (qc->stream0 == NULL)
-        return QUIC_RAISE_NON_NORMAL_ERROR(qc, ERR_R_INTERNAL_ERROR, NULL);
+    if (qc->stream0 == NULL) {
+        ret = QUIC_RAISE_NON_NORMAL_ERROR(qc, ERR_R_INTERNAL_ERROR, NULL);
+        goto out;
+    }
 
-    if (!quic_read_actual(qc, qc->stream0, buf, len, bytes_read, peek))
-        return 0; /* quic_read_actual raised error here */
+    if (!quic_read_actual(qc, qc->stream0, buf, len, bytes_read, peek)) {
+        ret = 0; /* quic_read_actual raised error here */
+        goto out;
+    }
 
     if (*bytes_read > 0) {
         /*
@@ -1198,7 +1283,7 @@ static int quic_read(SSL *s, void *buf, size_t len, size_t *bytes_read, int peek
          * handling other aspects of the QUIC connection.
          */
         ossl_quic_reactor_tick(ossl_quic_channel_get_reactor(qc->ch));
-        return 1;
+        ret = 1;
     } else if (blocking_mode(qc)) {
         /*
          * We were not able to read anything immediately, so our stream
@@ -1213,16 +1298,23 @@ static int quic_read(SSL *s, void *buf, size_t len, size_t *bytes_read, int peek
         args.peek       = peek;
 
         res = block_until_pred(qc, quic_read_again, &args, 0);
-        if (res == 0)
-            return QUIC_RAISE_NON_NORMAL_ERROR(qc, ERR_R_INTERNAL_ERROR, NULL);
-        else if (res < 0)
-            return 0; /* quic_read_again raised error here */
+        if (res == 0) {
+            ret = QUIC_RAISE_NON_NORMAL_ERROR(qc, ERR_R_INTERNAL_ERROR, NULL);
+            goto out;
+        } else if (res < 0) {
+            ret = 0; /* quic_read_again raised error here */
+            goto out;
+        }
 
-        return 1;
+        ret = 1;
     } else {
         /* We did not get any bytes and are not in blocking mode. */
-        return QUIC_RAISE_NORMAL_ERROR(qc, SSL_ERROR_WANT_READ);
+        ret = QUIC_RAISE_NORMAL_ERROR(qc, SSL_ERROR_WANT_READ);
     }
+
+out:
+    quic_unlock(qc);
+    return ret;
 }
 
 int ossl_quic_read(SSL *s, void *buf, size_t len, size_t *bytes_read)
@@ -1239,7 +1331,7 @@ int ossl_quic_peek(SSL *s, void *buf, size_t len, size_t *bytes_read)
  * SSL_pending
  * -----------
  */
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 static size_t ossl_quic_pending_int(const QUIC_CONNECTION *qc)
 {
     size_t avail = 0;
@@ -1248,13 +1340,18 @@ static size_t ossl_quic_pending_int(const QUIC_CONNECTION *qc)
     if (!expect_quic_conn(qc))
         return 0;
 
+    if (!quic_lock((QUIC_CONNECTION *)qc))
+        return 0;
+
     if (qc->stream0 == NULL || qc->stream0->rstream == NULL)
         /* Cannot raise errors here because we are const, just fail. */
-        return 0;
+        goto out;
 
     if (!ossl_quic_rstream_available(qc->stream0->rstream, &avail, &fin))
-        return 0;
+        avail = 0;
 
+out:
+    quic_unlock((QUIC_CONNECTION *)qc);
     return avail;
 }
 
@@ -1274,20 +1371,28 @@ int ossl_quic_has_pending(const QUIC_CONNECTION *qc)
  * SSL_stream_conclude
  * -------------------
  */
-QUIC_TODO_LOCK
+QUIC_TAKES_LOCK
 int ossl_quic_conn_stream_conclude(QUIC_CONNECTION *qc)
 {
     QUIC_STREAM *qs = qc->stream0;
 
-    if (qs == NULL || qs->sstream == NULL)
+    if (!quic_lock(qc))
         return 0;
 
+    if (qs == NULL || qs->sstream == NULL) {
+        quic_unlock(qc);
+        return 0;
+    }
+
     if (!ossl_quic_channel_is_active(qc->ch)
-        || ossl_quic_sstream_get_final_size(qs->sstream, NULL))
+        || ossl_quic_sstream_get_final_size(qs->sstream, NULL)) {
+        quic_unlock(qc);
         return 1;
+    }
 
     ossl_quic_sstream_fin(qs->sstream);
     quic_post_write(qc, 1, 1);
+    quic_unlock(qc);
     return 1;
 }