]> git.ipfire.org Git - thirdparty/haproxy.git/commitdiff
MINOR: quic: Atomically get/set the connection state
authorFrédéric Lécaille <flecaille@haproxy.com>
Wed, 18 Aug 2021 07:16:01 +0000 (09:16 +0200)
committerAmaury Denoyelle <adenoyelle@haproxy.com>
Thu, 23 Sep 2021 13:27:25 +0000 (15:27 +0200)
As ->state quic_conn struct member field is shared between threads
we must atomically get and set its value.

src/xprt_quic.c

index 04c472202420fe3521f3adc380893b38ae882bb8..96b6a47694e0cc27e3a4f017e8bfb1cf8b28b694 100644 (file)
@@ -594,7 +594,7 @@ static inline int quic_peer_validated_addr(struct ssl_sock_ctx *ctx)
 
        if ((qc->els[QUIC_TLS_ENC_LEVEL_HANDSHAKE].pktns->flags & QUIC_FL_PKTNS_ACK_RECEIVED) ||
            (qc->els[QUIC_TLS_ENC_LEVEL_APP].pktns->flags & QUIC_FL_PKTNS_ACK_RECEIVED) ||
-           (qc->state & QUIC_HS_ST_COMPLETE))
+           HA_ATOMIC_LOAD(&qc->state) >= QUIC_HS_ST_COMPLETE)
                return 1;
 
        return 0;
@@ -608,6 +608,7 @@ static inline void qc_set_timer(struct ssl_sock_ctx *ctx)
        struct quic_conn *qc;
        struct quic_pktns *pktns;
        unsigned int pto;
+       int handshake_complete;
 
        TRACE_ENTER(QUIC_EV_CONN_STIMER, ctx->conn,
                    NULL, NULL, &ctx->conn->qc->path->ifae_pkts);
@@ -629,7 +630,8 @@ static inline void qc_set_timer(struct ssl_sock_ctx *ctx)
                goto out;
        }
 
-       pktns = quic_pto_pktns(qc, qc->state & QUIC_HS_ST_COMPLETE, &pto);
+       handshake_complete = HA_ATOMIC_LOAD(&qc->state) >= QUIC_HS_ST_COMPLETE;
+       pktns = quic_pto_pktns(qc, handshake_complete, &pto);
        if (tick_isset(pto))
                qc->timer = pto;
  out:
@@ -1495,7 +1497,7 @@ static inline int qc_provide_cdata(struct quic_enc_level *el,
                                    struct quic_rx_packet *pkt,
                                    struct quic_rx_crypto_frm *cf)
 {
-       int ssl_err;
+       int ssl_err, state;
        struct quic_conn *qc;
 
        TRACE_ENTER(QUIC_EV_CONN_SSLDATA, ctx->conn);
@@ -1511,43 +1513,44 @@ static inline int qc_provide_cdata(struct quic_enc_level *el,
        TRACE_PROTO("in order CRYPTO data",
                    QUIC_EV_CONN_SSLDATA, ctx->conn,, cf, ctx->ssl);
 
-       if (qc->state < QUIC_HS_ST_COMPLETE) {
+       state = HA_ATOMIC_LOAD(&qc->state);
+       if (state < QUIC_HS_ST_COMPLETE) {
                ssl_err = SSL_do_handshake(ctx->ssl);
                if (ssl_err != 1) {
                        ssl_err = SSL_get_error(ctx->ssl, ssl_err);
                        if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
                                TRACE_PROTO("SSL handshake",
-                                           QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err);
+                                           QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err);
                                goto out;
                        }
 
                        TRACE_DEVEL("SSL handshake error",
-                                   QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err);
+                                   QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err);
                        goto err;
                }
 
-               TRACE_PROTO("SSL handshake OK", QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state);
+               TRACE_PROTO("SSL handshake OK", QUIC_EV_CONN_HDSHK, ctx->conn, &state);
                if (objt_listener(ctx->conn->target))
-                       qc->state = QUIC_HS_ST_CONFIRMED;
+                       HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_CONFIRMED);
                else
-                       qc->state = QUIC_HS_ST_COMPLETE;
+                       HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_COMPLETE);
        } else {
                ssl_err = SSL_process_quic_post_handshake(ctx->ssl);
                if (ssl_err != 1) {
                        ssl_err = SSL_get_error(ctx->ssl, ssl_err);
                        if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
                                TRACE_DEVEL("SSL post handshake",
-                                           QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err);
+                                           QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err);
                                goto out;
                        }
 
                        TRACE_DEVEL("SSL post handshake error",
-                                   QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err);
+                                   QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err);
                        goto err;
                }
 
                TRACE_PROTO("SSL post handshake succeeded",
-                           QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state);
+                           QUIC_EV_CONN_HDSHK, ctx->conn, &state);
        }
 
  out:
@@ -1938,7 +1941,7 @@ static int qc_parse_pkt_frms(struct quic_rx_packet *pkt, struct ssl_sock_ctx *ct
                        if (objt_listener(ctx->conn->target))
                                goto err;
 
-                       conn->state = QUIC_HS_ST_CONFIRMED;
+                       HA_ATOMIC_STORE(&conn->state, QUIC_HS_ST_CONFIRMED);
                        break;
                default:
                        goto err;
@@ -1949,12 +1952,12 @@ static int qc_parse_pkt_frms(struct quic_rx_packet *pkt, struct ssl_sock_ctx *ct
         * has successfully parse a Handshake packet. The Initial encryption must also
         * be discarded.
         */
-       if (conn->state == QUIC_HS_ST_SERVER_INITIAL &&
+       if (HA_ATOMIC_LOAD(&conn->state) == QUIC_HS_ST_SERVER_INITIAL &&
            pkt->type == QUIC_PACKET_TYPE_HANDSHAKE) {
                quic_tls_discard_keys(&conn->els[QUIC_TLS_ENC_LEVEL_INITIAL]);
                quic_pktns_discard(conn->els[QUIC_TLS_ENC_LEVEL_INITIAL].pktns, conn);
                qc_set_timer(ctx);
-               conn->state = QUIC_HS_ST_SERVER_HANDSHAKE;
+               HA_ATOMIC_STORE(&conn->state, QUIC_HS_ST_SERVER_HANDSHAKE);
        }
 
        TRACE_LEAVE(QUIC_EV_CONN_PRSHPKT, ctx->conn);
@@ -2004,7 +2007,7 @@ static int qc_prep_hdshk_pkts(struct qring *qr, struct ssl_sock_ctx *ctx)
 
        TRACE_ENTER(QUIC_EV_CONN_PHPKTS, ctx->conn);
        qc = ctx->conn->qc;
-       if (!quic_get_tls_enc_levels(&tel, &next_tel, qc->state)) {
+       if (!quic_get_tls_enc_levels(&tel, &next_tel, HA_ATOMIC_LOAD(&qc->state))) {
                TRACE_DEVEL("unknown enc. levels", QUIC_EV_CONN_PHPKTS, ctx->conn);
                goto err;
        }
@@ -2088,12 +2091,12 @@ static int qc_prep_hdshk_pkts(struct qring *qr, struct ssl_sock_ctx *ctx)
                        /* Discard the Initial encryption keys as soon as
                         * a handshake packet could be built.
                         */
-                       if (qc->state == QUIC_HS_ST_CLIENT_INITIAL &&
+                       if (HA_ATOMIC_LOAD(&qc->state) == QUIC_HS_ST_CLIENT_INITIAL &&
                            pkt_type == QUIC_PACKET_TYPE_HANDSHAKE) {
                                quic_tls_discard_keys(&qc->els[QUIC_TLS_ENC_LEVEL_INITIAL]);
                                quic_pktns_discard(qc->els[QUIC_TLS_ENC_LEVEL_INITIAL].pktns, qc);
                                qc_set_timer(ctx);
-                               qc->state = QUIC_HS_ST_CLIENT_HANDSHAKE;
+                               HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_CLIENT_HANDSHAKE);
                        }
                        /* Special case for Initial packets: when they have all
                         * been sent, select the next level.
@@ -2478,7 +2481,7 @@ static inline void qc_rm_hp_pkts(struct quic_enc_level *el, struct ssl_sock_ctx
        app_qel = &ctx->conn->qc->els[QUIC_TLS_ENC_LEVEL_APP];
        /* A server must not process incoming 1-RTT packets before the handshake is complete. */
        if (el == app_qel && objt_listener(ctx->conn->target) &&
-           ctx->conn->qc->state < QUIC_HS_ST_COMPLETE) {
+           HA_ATOMIC_LOAD(&ctx->conn->qc->state) < QUIC_HS_ST_COMPLETE) {
                TRACE_PROTO("hp not removed (handshake not completed)",
                            QUIC_EV_CONN_ELRMHP, ctx->conn);
                goto out;
@@ -2622,9 +2625,10 @@ struct task *quic_conn_io_cb(struct task *t, void *context, unsigned int state)
        ctx = context;
        qc = ctx->conn->qc;
        qr = NULL;
-       TRACE_ENTER(QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state);
+       st = HA_ATOMIC_LOAD(&qc->state);
+       TRACE_ENTER(QUIC_EV_CONN_HDSHK, ctx->conn, &st);
        ssl_err = SSL_ERROR_NONE;
-       if (!quic_get_tls_enc_levels(&tel, &next_tel, qc->state))
+       if (!quic_get_tls_enc_levels(&tel, &next_tel, st))
                goto err;
 
        qel = &qc->els[tel];
@@ -2674,15 +2678,14 @@ struct task *quic_conn_io_cb(struct task *t, void *context, unsigned int state)
                goto next_level;
        }
 
- out:
        MT_LIST_APPEND(qc->tx.qring_list, &qr->mt_list);
-       TRACE_LEAVE(QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state);
+       TRACE_LEAVE(QUIC_EV_CONN_HDSHK, ctx->conn, &st);
        return t;
 
  err:
        if (qr)
                MT_LIST_APPEND(qc->tx.qring_list, &qr->mt_list);
-       TRACE_DEVEL("leaving in error", QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err);
+       TRACE_DEVEL("leaving in error", QUIC_EV_CONN_HDSHK, ctx->conn, &st, &ssl_err);
        return t;
 }
 
@@ -2768,7 +2771,7 @@ static struct task *process_timer(struct task *task, void *ctx, unsigned int sta
        struct ssl_sock_ctx *conn_ctx;
        struct quic_conn *qc;
        struct quic_pktns *pktns;
-
+       int st;
 
        conn_ctx = task->context;
        qc = conn_ctx->conn->qc;
@@ -2786,11 +2789,12 @@ static struct task *process_timer(struct task *task, void *ctx, unsigned int sta
                goto out;
        }
 
+       st = HA_ATOMIC_LOAD(&qc->state);
        if (qc->path->in_flight) {
-               pktns = quic_pto_pktns(qc, qc->state >= QUIC_HS_ST_COMPLETE, NULL);
+               pktns = quic_pto_pktns(qc, st >= QUIC_HS_ST_COMPLETE, NULL);
                pktns->tx.pto_probe = 1;
        }
-       else if (objt_server(qc->conn->target) && qc->state <= QUIC_HS_ST_COMPLETE) {
+       else if (objt_server(qc->conn->target) && st <= QUIC_HS_ST_COMPLETE) {
                struct quic_enc_level *iel = &qc->els[QUIC_TLS_ENC_LEVEL_INITIAL];
                struct quic_enc_level *hel = &qc->els[QUIC_TLS_ENC_LEVEL_HANDSHAKE];
 
@@ -2837,7 +2841,7 @@ static struct quic_conn *qc_new_conn(unsigned int version, int ipv4,
        if (server) {
                struct listener *l = owner;
 
-               qc->state = QUIC_HS_ST_SERVER_INITIAL;
+               HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_SERVER_INITIAL);
                /* Copy the initial DCID. */
                qc->odcid.len = dcid_len;
                if (qc->odcid.len)
@@ -2851,7 +2855,7 @@ static struct quic_conn *qc_new_conn(unsigned int version, int ipv4,
        }
        /* QUIC Client (outgoing connection to servers) */
        else {
-               qc->state = QUIC_HS_ST_CLIENT_INITIAL;
+               HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_CLIENT_INITIAL);
                if (dcid_len)
                        memcpy(qc->dcid.data, dcid, dcid_len);
                qc->dcid.len = dcid_len;
@@ -3000,7 +3004,8 @@ static int qc_pkt_may_rm_hp(struct quic_rx_packet *pkt,
        }
 
        if (((*qel)->tls_ctx.rx.flags & QUIC_FL_TLS_SECRETS_SET) &&
-           (tel != QUIC_TLS_ENC_LEVEL_APP || ctx->conn->qc->state >= QUIC_HS_ST_COMPLETE))
+           (tel != QUIC_TLS_ENC_LEVEL_APP ||
+            HA_ATOMIC_LOAD(&ctx->conn->qc->state) >= QUIC_HS_ST_COMPLETE))
                return 1;
 
        return 0;
@@ -4276,14 +4281,15 @@ static int qc_conn_init(struct connection *conn, void **xprt_ctx)
                SSL_set_connect_state(ctx->ssl);
                ssl_err = SSL_do_handshake(ctx->ssl);
                if (ssl_err != 1) {
+                       int st;
+
+                       st = HA_ATOMIC_LOAD(&qc->state);
                        ssl_err = SSL_get_error(ctx->ssl, ssl_err);
                        if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
-                               TRACE_PROTO("SSL handshake",
-                                           QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err);
+                               TRACE_PROTO("SSL handshake", QUIC_EV_CONN_HDSHK, ctx->conn, &st, &ssl_err);
                        }
                        else {
-                               TRACE_DEVEL("SSL handshake error",
-                                           QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err);
+                               TRACE_DEVEL("SSL handshake error", QUIC_EV_CONN_HDSHK, ctx->conn, &st, &ssl_err);
                                goto err;
                        }
                }