]> git.ipfire.org Git - thirdparty/curl.git/commitdiff
openssl-quic: do not iterate over multi handles
authorStefan Eissing <stefan@eissing.org>
Fri, 7 Mar 2025 12:09:16 +0000 (13:09 +0100)
committerDaniel Stenberg <daniel@haxx.se>
Fri, 7 Mar 2025 13:54:25 +0000 (14:54 +0100)
Iterate over the filters stream hash instead, lookup easy handles
at the multi when needed.

This also limits to pollset array sizes to the number of streams
on the connection and not the total number of transfers in the multi.

Closes #16611

lib/vquic/curl_osslq.c
tests/http/test_14_auth.py

index a69da56803d15b6b29a11c85713fc643507d8c4d..c221124d723bdfd00ad6eb38424c5a31f34c5860 100644 (file)
@@ -290,7 +290,7 @@ struct cf_osslq_ctx {
   uint64_t max_idle_ms;              /* max idle time for QUIC connection */
   SSL_POLL_ITEM *poll_items;         /* Array for polling on writable state */
   struct Curl_easy **curl_items;     /* Array of easy objs */
-  size_t item_count;                 /* count of elements in poll/curl_items */
+  size_t items_max;                  /* max elements in poll/curl_items */
   BIT(initialized);
   BIT(got_first_byte);               /* if first byte was received */
   BIT(x509_store_setup);             /* if x509 store has been set up */
@@ -309,7 +309,7 @@ static void cf_osslq_ctx_init(struct cf_osslq_ctx *ctx)
   Curl_hash_offt_init(&ctx->streams, 63, h3_stream_hash_free);
   ctx->poll_items = NULL;
   ctx->curl_items = NULL;
-  ctx->item_count = 0;
+  ctx->items_max = 0;
   ctx->initialized = TRUE;
 }
 
@@ -666,6 +666,24 @@ static void h3_data_done(struct Curl_cfilter *cf, struct Curl_easy *data)
   }
 }
 
+struct cf_ossq_find_ctx {
+  curl_int64_t stream_id;
+  struct h3_stream_ctx *stream;
+};
+
+static bool cf_osslq_find_stream(curl_off_t mid, void *val, void *user_data)
+{
+  struct h3_stream_ctx *stream = val;
+  struct cf_ossq_find_ctx *fctx = user_data;
+
+  (void)mid;
+  if(stream && stream->s.id == fctx->stream_id) {
+    fctx->stream = stream;
+    return FALSE; /* stop iterating */
+  }
+  return TRUE;
+}
+
 static struct cf_osslq_stream *cf_osslq_get_qstream(struct Curl_cfilter *cf,
                                                     struct Curl_easy *data,
                                                     int64_t stream_id)
@@ -686,17 +704,12 @@ static struct cf_osslq_stream *cf_osslq_get_qstream(struct Curl_cfilter *cf,
     return &ctx->h3.s_qpack_dec;
   }
   else {
-    struct Curl_llist_node *e;
-    DEBUGASSERT(data->multi);
-    for(e = Curl_llist_head(&data->multi->process); e; e = Curl_node_next(e)) {
-      struct Curl_easy *sdata = Curl_node_elem(e);
-      if(sdata->conn != data->conn)
-        continue;
-      stream = H3_STREAM_CTX(ctx, sdata);
-      if(stream && stream->s.id == stream_id) {
-        return &stream->s;
-      }
-    }
+    struct cf_ossq_find_ctx fctx;
+    fctx.stream_id = stream_id;
+    fctx.stream = NULL;
+    Curl_hash_offt_visit(&ctx->streams, cf_osslq_find_stream, &fctx);
+    if(fctx.stream)
+      return &fctx.stream->s;
   }
   return NULL;
 }
@@ -1401,6 +1414,29 @@ out:
   return result;
 }
 
+struct cf_ossq_recv_ctx {
+  struct Curl_cfilter *cf;
+  struct Curl_multi *multi;
+  CURLcode result;
+};
+
+static bool cf_osslq_iter_recv(curl_off_t mid, void *val, void *user_data)
+{
+  struct h3_stream_ctx *stream = val;
+  struct cf_ossq_recv_ctx *rctx = user_data;
+
+  (void)mid;
+  if(stream && !stream->closed && !Curl_bufq_is_full(&stream->recvbuf)) {
+    struct Curl_easy *sdata = Curl_multi_get_handle(rctx->multi, mid);
+    if(sdata) {
+      rctx->result = cf_osslq_stream_recv(&stream->s, rctx->cf, sdata);
+      if(rctx->result)
+        return FALSE; /* abort iteration */
+    }
+  }
+  return TRUE;
+}
+
 static CURLcode cf_progress_ingress(struct Curl_cfilter *cf,
                                     struct Curl_easy *data)
 {
@@ -1437,22 +1473,14 @@ static CURLcode cf_progress_ingress(struct Curl_cfilter *cf,
   }
 
   if(ctx->h3.conn) {
-    struct Curl_llist_node *e;
-    struct h3_stream_ctx *stream;
-    /* PULL all open streams */
+    struct cf_ossq_recv_ctx rctx;
+
     DEBUGASSERT(data->multi);
-    for(e = Curl_llist_head(&data->multi->process); e; e = Curl_node_next(e)) {
-      struct Curl_easy *sdata = Curl_node_elem(e);
-      if(sdata->conn == data->conn && CURL_WANT_RECV(sdata)) {
-        stream = H3_STREAM_CTX(ctx, sdata);
-        if(stream && !stream->closed &&
-           !Curl_bufq_is_full(&stream->recvbuf)) {
-          result = cf_osslq_stream_recv(&stream->s, cf, sdata);
-          if(result)
-            goto out;
-        }
-      }
-    }
+    rctx.cf = cf;
+    rctx.multi = data->multi;
+    rctx.result = CURLE_OK;
+    Curl_hash_offt_visit(&ctx->streams, cf_osslq_iter_recv, &rctx);
+    result = rctx.result;
   }
 
 out:
@@ -1460,13 +1488,43 @@ out:
   return result;
 }
 
+struct cf_ossq_fill_ctx {
+  struct cf_osslq_ctx *ctx;
+  struct Curl_multi *multi;
+  size_t n;
+};
+
+static bool cf_osslq_collect_block_send(curl_off_t mid, void *val,
+                                        void *user_data)
+{
+  struct h3_stream_ctx *stream = val;
+  struct cf_ossq_fill_ctx *fctx = user_data;
+  struct cf_osslq_ctx *ctx = fctx->ctx;
+
+  if(fctx->n >= ctx->items_max)  /* should not happen, prevent mayhem */
+    return FALSE;
+
+  if(stream && stream->s.ssl && stream->s.send_blocked) {
+    struct Curl_easy *sdata = Curl_multi_get_handle(fctx->multi, mid);
+    fprintf(stderr, "[OSSLQ] stream %" FMT_PRId64 " sdata=%p\n",
+            stream->s.id, (void *)sdata);
+    if(sdata) {
+      ctx->poll_items[fctx->n].desc = SSL_as_poll_descriptor(stream->s.ssl);
+      ctx->poll_items[fctx->n].events = SSL_POLL_EVENT_W;
+      ctx->curl_items[fctx->n] = sdata;
+      fctx->n++;
+    }
+  }
+  return TRUE;
+}
+
 /* Iterate over all streams and check if blocked can be unblocked */
 static CURLcode cf_osslq_check_and_unblock(struct Curl_cfilter *cf,
                                            struct Curl_easy *data)
 {
   struct cf_osslq_ctx *ctx = cf->ctx;
   struct h3_stream_ctx *stream;
-  size_t poll_count = 0;
+  size_t poll_count;
   size_t result_count = 0;
   size_t idx_count = 0;
   CURLcode res = CURLE_OK;
@@ -1474,66 +1532,58 @@ static CURLcode cf_osslq_check_and_unblock(struct Curl_cfilter *cf,
   void *tmpptr;
 
   if(ctx->h3.conn) {
-    struct Curl_llist_node *e;
-
-    res = CURLE_OUT_OF_MEMORY;
+    struct cf_ossq_fill_ctx fill_ctx;
 
-    if(ctx->item_count < Curl_llist_count(&data->multi->process)) {
-      ctx->item_count = 0;
-      tmpptr = realloc(ctx->poll_items,
-                       Curl_llist_count(&data->multi->process) *
-                       sizeof(SSL_POLL_ITEM));
+    if(ctx->items_max < Curl_hash_offt_count(&ctx->streams)) {
+      size_t nmax = Curl_hash_offt_count(&ctx->streams);
+      ctx->items_max = 0;
+      tmpptr = realloc(ctx->poll_items, nmax * sizeof(SSL_POLL_ITEM));
       if(!tmpptr) {
         free(ctx->poll_items);
         ctx->poll_items = NULL;
+        res = CURLE_OUT_OF_MEMORY;
         goto out;
       }
       ctx->poll_items = tmpptr;
 
-      tmpptr = realloc(ctx->curl_items,
-                       Curl_llist_count(&data->multi->process) *
-                       sizeof(struct Curl_easy *));
+      tmpptr = realloc(ctx->curl_items, nmax * sizeof(struct Curl_easy *));
       if(!tmpptr) {
         free(ctx->curl_items);
         ctx->curl_items = NULL;
+        res = CURLE_OUT_OF_MEMORY;
         goto out;
       }
       ctx->curl_items = tmpptr;
-
-      ctx->item_count = Curl_llist_count(&data->multi->process);
-    }
-
-    for(e = Curl_llist_head(&data->multi->process); e; e = Curl_node_next(e)) {
-      struct Curl_easy *sdata = Curl_node_elem(e);
-      if(sdata->conn == data->conn) {
-        stream = H3_STREAM_CTX(ctx, sdata);
-        if(stream && stream->s.ssl && stream->s.send_blocked) {
-          ctx->poll_items[poll_count].desc =
-            SSL_as_poll_descriptor(stream->s.ssl);
-          ctx->poll_items[poll_count].events = SSL_POLL_EVENT_W;
-          ctx->curl_items[poll_count] = sdata;
-          poll_count++;
-        }
-      }
+      ctx->items_max = nmax;
     }
 
-    memset(&timeout, 0, sizeof(struct timeval));
-    res = CURLE_UNRECOVERABLE_POLL;
-    if(!SSL_poll(ctx->poll_items, poll_count, sizeof(SSL_POLL_ITEM), &timeout,
-                 0, &result_count))
-        goto out;
+    fill_ctx.ctx = ctx;
+    fill_ctx.multi = data->multi;
+    fill_ctx.n = 0;
+    Curl_hash_offt_visit(&ctx->streams, cf_osslq_collect_block_send,
+                         &fill_ctx);
+    poll_count = fill_ctx.n;
+    if(poll_count) {
+      CURL_TRC_CF(data, cf, "polling %zu blocked streams", poll_count);
+
+      memset(&timeout, 0, sizeof(struct timeval));
+      res = CURLE_UNRECOVERABLE_POLL;
+      if(!SSL_poll(ctx->poll_items, poll_count, sizeof(SSL_POLL_ITEM),
+                   &timeout, 0, &result_count))
+          goto out;
 
-    res = CURLE_OK;
-
-    for(idx_count = 0; idx_count < poll_count && result_count > 0;
-        idx_count++) {
-      if(ctx->poll_items[idx_count].revents & SSL_POLL_EVENT_W) {
-        stream = H3_STREAM_CTX(ctx, ctx->curl_items[idx_count]);
-        nghttp3_conn_unblock_stream(ctx->h3.conn, stream->s.id);
-        stream->s.send_blocked = FALSE;
-        h3_drain_stream(cf, ctx->curl_items[idx_count]);
-        CURL_TRC_CF(ctx->curl_items[idx_count], cf, "unblocked");
-        result_count--;
+      res = CURLE_OK;
+
+      for(idx_count = 0; idx_count < poll_count && result_count > 0;
+          idx_count++) {
+        if(ctx->poll_items[idx_count].revents & SSL_POLL_EVENT_W) {
+          stream = H3_STREAM_CTX(ctx, ctx->curl_items[idx_count]);
+          nghttp3_conn_unblock_stream(ctx->h3.conn, stream->s.id);
+          stream->s.send_blocked = FALSE;
+          h3_drain_stream(cf, ctx->curl_items[idx_count]);
+          CURL_TRC_CF(ctx->curl_items[idx_count], cf, "unblocked");
+          result_count--;
+        }
       }
     }
   }
index 237d7ecda8f32fbdd0b95dc8df48ccaa72d698d8..13193b53b8a24c65f7fcd8bd237bfeca3fea3c18 100644 (file)
@@ -134,4 +134,4 @@ class TestAuth:
         # Depending on protocol, we might have an error sending or
         # the server might shutdown the connection and we see the error
         # on receiving
-        assert r.exit_code in [55, 56], f'{r.dump_logs()}'
+        assert r.exit_code in [55, 56, 95], f'{r.dump_logs()}'