]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Arrange the code following the reviews.
authorJean-Frederic Clere <jfclere@gmail.com>
Thu, 5 Dec 2024 14:48:25 +0000 (15:48 +0100)
committerNeil Horman <nhorman@openssl.org>
Mon, 17 Feb 2025 16:27:33 +0000 (11:27 -0500)
Reviewed-by: Neil Horman <nhorman@openssl.org>
Reviewed-by: Saša Nedvědický <sashan@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/25859)

demos/http3/ossl-nghttp3-demo-server.c

index a1ef0b4f3d7092974f16d88855a70760d8d1e774..87b9cdc35a975c3ceadb9f3143eea6778598d114 100644 (file)
@@ -40,6 +40,8 @@ struct ssl_id {
 #define SERVERUNIOPEN  0x08 /* unidirectional open by the server (3, 7 and 11) */
 #define SERVERCLOSED   0x10 /* closed by the server (us) */
 #define TOBEREMOVED    0x20 /* marked for removing in read_from_ssl_ids, removed after processing all events */
+#define ISLISTENER     0x40 /* the stream is a listener from SSL_new_listener() */
+#define ISCONNECTION   0x80 /* the stream is a connection from SSL_accept_connection() */
 
 #define MAXSSL_IDS 20
 #define MAXURL 255
@@ -100,7 +102,7 @@ static void reuse_h3ssl(struct h3ssl *h3ssl)
     h3ssl->ldata = 0;
 }
 
-static void add_id(uint64_t id, SSL *ssl, struct h3ssl *h3ssl)
+static void add_id_status(uint64_t id, SSL *ssl, struct h3ssl *h3ssl, int status)
 {
     struct ssl_id *ssl_ids;
     int i;
@@ -110,25 +112,55 @@ static void add_id(uint64_t id, SSL *ssl, struct h3ssl *h3ssl)
         if (ssl_ids[i].s == NULL) {
             ssl_ids[i].s = ssl;
             ssl_ids[i].id = id;
+            ssl_ids[i].status = status;
             return;
         }
     }
     printf("Oops too many streams to add!!!\n");
     exit(1);
 }
-static void add_id_at(uint64_t id, SSL *ssl, int at, struct h3ssl *h3ssl)
+static void add_id(uint64_t id, SSL *ssl, struct h3ssl *h3ssl)
+{
+    add_id_status(id, ssl, h3ssl, 0);
+}
+
+/* Add listener and connection */
+static void add_ids_listener(SSL *ssl, struct h3ssl *h3ssl)
+{
+    add_id_status(UINT64_MAX, ssl, h3ssl, ISLISTENER);
+}
+static void add_ids_connection(struct h3ssl *h3ssl, SSL *ssl)
+{
+    add_id_status(UINT64_MAX, ssl, h3ssl, ISCONNECTION);
+}
+static SSL *get_ids_connection(struct h3ssl *h3ssl)
 {
     struct ssl_id *ssl_ids;
+    int i;
 
     ssl_ids = h3ssl->ssl_ids;
-    if (ssl_ids[at].s == NULL) {
-        ssl_ids[at].s = ssl;
-        ssl_ids[at].id = id;
-        return;
+    for (i = 0; i < MAXSSL_IDS; i++) {
+        if (ssl_ids[i].status & ISCONNECTION) {
+            printf("get_ids_connection\n");
+            return ssl_ids[i].s;
+        }
     }
-    printf("Oops %d already used\n", at);
-    exit(1);
+    return NULL;
 }
+static void replace_ids_connection(struct h3ssl *h3ssl, SSL *oldstream, SSL *newstream )
+{
+    struct ssl_id *ssl_ids;
+    int i;
+
+    ssl_ids = h3ssl->ssl_ids;
+    for (i = 0; i < MAXSSL_IDS; i++) {
+        if (ssl_ids[i].status & ISCONNECTION && ssl_ids[i].s == oldstream) {
+            printf("replace_ids_connection\n");
+            ssl_ids[i].s = newstream;
+        }
+    }
+}
+
 
 /* remove the ids marked for removal */
 static void remove_marked_ids(struct h3ssl *h3ssl)
@@ -163,7 +195,7 @@ static void set_id_status(uint64_t id, int status, struct h3ssl *h3ssl)
             return;
         }
     }
-    printf("Oops can't get status, can't find stream!!!\n");
+    printf("Oops can't set status, can't find stream!!!\n");
     assert(0);
 }
 static int get_id_status(uint64_t id, struct h3ssl *h3ssl)
@@ -179,7 +211,7 @@ static int get_id_status(uint64_t id, struct h3ssl *h3ssl)
             return ssl_ids[i].status;
         }
     }
-    printf("Oops can't set status, can't find stream!!!\n");
+    printf("Oops can't get status, can't find stream!!!\n");
     assert(0);
     return -1;
 }
@@ -193,10 +225,6 @@ static int are_all_clientid_closed(struct h3ssl *h3ssl)
     for (i = 0; i < MAXSSL_IDS; i++) {
         if (ssl_ids[i].id == UINT64_MAX)
             continue;
-        if (ssl_ids[i].status & CLIENTUNIOPEN) {
-            printf("are_all_clientid_closed: %llu open\n", (unsigned long long) ssl_ids[i].id);
-            return 0;
-        }
         printf("are_all_clientid_closed: %llu status %d : %d\n",
                (unsigned long long) ssl_ids[i].id, ssl_ids[i].status, CLIENTUNIOPEN | CLIENTCLOSED);
         if (ssl_ids[i].status & (CLIENTUNIOPEN | CLIENTCLOSED)) {
@@ -204,11 +232,17 @@ static int are_all_clientid_closed(struct h3ssl *h3ssl)
             SSL_free(ssl_ids[i].s);
             ssl_ids[i].s = NULL;
             ssl_ids[i].id = UINT64_MAX;
+            continue;
+        }
+        if (ssl_ids[i].status & CLIENTUNIOPEN) {
+            printf("are_all_clientid_closed: %llu open\n", (unsigned long long) ssl_ids[i].id);
+            return 0;
         }
     }
     return 1;
 }
 
+/* free all the ids except listener and connection */
 static void close_all_ids(struct h3ssl *h3ssl)
 {
     struct ssl_id *ssl_ids;
@@ -442,6 +476,11 @@ static int read_from_ssl_ids(nghttp3_conn **curh3conn, struct h3ssl *h3ssl)
         return 0;
     }
 
+    /* reset the states */
+    h3ssl->new_conn = 0;
+    h3ssl->restart = 0;
+    h3ssl->done = 0;
+
     /* Process all the item we have polled */
     item = NULL;
     for (i = 0; i < numitem; i++) {
@@ -460,6 +499,7 @@ static int read_from_ssl_ids(nghttp3_conn **curh3conn, struct h3ssl *h3ssl)
         /* New connection */
         if (item->revents & SSL_POLL_EVENT_IC) {
             SSL *conn = SSL_accept_connection(item->desc.value.ssl, 0);
+            SSL *oldconn;
 
             printf("SSL_accept_connection\n");
             if (conn == NULL) {
@@ -469,27 +509,32 @@ static int read_from_ssl_ids(nghttp3_conn **curh3conn, struct h3ssl *h3ssl)
             }
 
             /* the previous might be still there */
-            if (ssl_ids[1].s) {
+            oldconn = get_ids_connection(h3ssl);
+            if (oldconn != NULL) {
                 /* XXX we support only one connection for the moment */
                 printf("SSL_accept_connection closing previous\n");
-                SSL_free(h3ssl->ssl_ids[1].s);
-                h3ssl->ssl_ids[1].s = NULL;
+                SSL_free(oldconn);
+                replace_ids_connection(h3ssl, oldconn, conn);
                 reuse_h3ssl(h3ssl);
                 close_all_ids(h3ssl);
                 h3ssl->id_bidi = UINT64_MAX;
                 h3ssl->has_uni = 0;
-                /* we need a new h3conn here!!! */
-                if (nghttp3_conn_server_new(curh3conn, &callbacks, &settings, mem,
-                                            h3ssl)) {
-                    fprintf(stderr, "nghttp3_conn_client_new failed!\n");
-                    exit(1);
-                }
-                /* XXX : nghttp3_conn_del() for the old one? */
-                h3conn = *curh3conn;
-
-                hassomething++;
-                h3ssl->new_conn = 1;
+            } else {
+                printf("SSL_accept_connection first connection\n");
+                add_ids_connection(h3ssl, conn);
+            }
+            h3ssl->new_conn = 1;
+            /* create the new h3conn */
+            nghttp3_conn_del(*curh3conn);
+            nghttp3_settings_default(&settings);
+            if (nghttp3_conn_server_new(curh3conn, &callbacks, &settings, mem,
+                                        h3ssl)) {
+                fprintf(stderr, "nghttp3_conn_client_new failed!\n");
+                exit(1);
             }
+            h3conn = *curh3conn;
+            hassomething++;
+
             if (!SSL_set_incoming_stream_policy(conn,
                                                 SSL_INCOMING_STREAM_POLICY_ACCEPT, 0)) {
                 fprintf(stderr, "error while setting inccoming stream policy\n");
@@ -497,7 +542,6 @@ static int read_from_ssl_ids(nghttp3_conn **curh3conn, struct h3ssl *h3ssl)
                 goto err;
             }
 
-            add_id_at(-1, conn, 1, h3ssl);
             printf("SSL_accept_connection\n");
             processed_event = processed_event | SSL_POLL_EVENT_IC;
         }
@@ -667,7 +711,7 @@ static void handle_events_from_ids(struct h3ssl *h3ssl)
 
     ssl_ids = h3ssl->ssl_ids;
     for (i = 0; i < MAXSSL_IDS; i++) {
-        if (ssl_ids[i].s != NULL) {
+        if (ssl_ids[i].s != NULL && (ssl_ids[i].status & ISCONNECTION || ssl_ids[i].status & ISLISTENER)) {
             if (SSL_handle_events(ssl_ids[i].s))
                 ERR_print_errors_fp(stderr);
         }
@@ -944,6 +988,9 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
     int ok = 0;
     int hassomething = 0;
     SSL *listener = NULL;
+    struct h3ssl h3ssl;
+    nghttp3_conn *h3conn = NULL;
+    SSL *ssl;
 
     /* Create a new QUIC listener. */
     if ((listener = SSL_new_listener(ctx, 0)) == NULL)
@@ -965,9 +1012,16 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
     if (!SSL_set_blocking_mode(listener, 0))
         goto err;
 
+    /* Setup callbacks. */
+    callbacks.recv_header = on_recv_header;
+    callbacks.end_headers = on_end_headers;
+    callbacks.recv_data = on_recv_data;
+    callbacks.end_stream = on_end_stream;
+
+    /* mem default */
+    mem = nghttp3_mem_default();
+
     for (;;) {
-        nghttp3_conn *h3conn = NULL;
-        struct h3ssl h3ssl;
         nghttp3_nv resp[10];
         size_t num_nv;
         nghttp3_data_reader dr;
@@ -976,6 +1030,10 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
         char slength[11];
         int hasnothing;
 
+        init_ids(&h3ssl);
+        printf("listener: %p\n", (void *)listener);
+        add_ids_listener(listener, &h3ssl);
+
         if (!hassomething) {
             printf("waiting on socket\n");
             fflush(stdout);
@@ -985,40 +1043,12 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
                 goto err;
             }
         }
-        printf("before SSL_accept_connection\n");
-        fflush(stdout);
-
         /*
          * Service the connection. In a real application this would be done
          * concurrently. In this demonstration program a single connection is
          * accepted and serviced at a time.
          */
-
-        /* try to use nghttp3 to send a response */
-        init_ids(&h3ssl);
-        printf("listener: %p\n", (void *)listener);
-        add_id_at(-1, listener, 0, &h3ssl);
-
-        /* Setup callbacks. */
-        callbacks.recv_header = on_recv_header;
-        callbacks.end_headers = on_end_headers;
-        callbacks.recv_data = on_recv_data;
-        callbacks.end_stream = on_end_stream;
-
-        /* mem default */
-        mem = nghttp3_mem_default();
-
-        handle_events_from_ids(&h3ssl);
-
     newconn:
-        if (h3conn == NULL) {
-            nghttp3_settings_default(&settings);
-            if (nghttp3_conn_server_new(&h3conn, &callbacks, &settings, mem,
-                                        &h3ssl)) {
-                fprintf(stderr, "nghttp3_conn_client_new failed!\n");
-                exit(1);
-            }
-        }
 
         printf("process_server starting...\n");
         fflush(stdout);
@@ -1029,9 +1059,7 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
         num_nv = 0;
         while (!h3ssl.end_headers_received) {
             if (!hassomething) {
-                printf("listener: %p waiting for end_headers_received\n",
-                       (void *) h3ssl.ssl_ids[0].s);
-                if (wait_for_activity(h3ssl.ssl_ids[0].s) == 0) {
+                if (wait_for_activity(listener) == 0) {
                     printf("waiting for end_headers_received timeout %d\n", numtimeout);
                     numtimeout++;
                     if (numtimeout == 25)
@@ -1175,8 +1203,11 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
         for (;;) {
 
             if (!hasnothing) {
+                SSL *ssl = get_ids_connection(&h3ssl);
                 printf("hasnothing nothing WAIT %d!!!\n", h3ssl.close_done);
-                ret = wait_for_activity(h3ssl.ssl_ids[1].s);
+                if (ssl == NULL)
+                    ssl = listener;
+                ret = wait_for_activity(ssl);
                 if (ret == -1)
                     goto err;
                 if (ret == 0)
@@ -1220,12 +1251,15 @@ static int run_quic_server(SSL_CTX *ctx, int fd)
         }
 
         /*
-         * Free the connection, then loop again, accepting another connection.
+         * Free the streams, then loop again, accepting another connection.
          */
         close_all_ids(&h3ssl);
-        SSL_free(h3ssl.ssl_ids[1].s);
-        h3ssl.ssl_ids[1].s = NULL;
-        h3conn = NULL; /* XXX need nghttp3_conn_del() ? */
+        ssl = get_ids_connection(&h3ssl);
+        if (ssl != NULL) {
+            SSL_free(ssl);
+            replace_ids_connection(&h3ssl, ssl, NULL);
+        }
+        hassomething = 0;
     }
 
     ok = 1;