From: Frédéric Lécaille Date: Wed, 18 Aug 2021 07:16:01 +0000 (+0200) Subject: MINOR: quic: Atomically get/set the connection state X-Git-Tag: v2.5-dev8~74 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=eed7a7d73b75db42b29dbc0ffd2736b7d071bd16;p=thirdparty%2Fhaproxy.git MINOR: quic: Atomically get/set the connection state As ->state quic_conn struct member field is shared between threads we must atomically get and set its value. --- diff --git a/src/xprt_quic.c b/src/xprt_quic.c index 04c4722024..96b6a47694 100644 --- a/src/xprt_quic.c +++ b/src/xprt_quic.c @@ -594,7 +594,7 @@ static inline int quic_peer_validated_addr(struct ssl_sock_ctx *ctx) if ((qc->els[QUIC_TLS_ENC_LEVEL_HANDSHAKE].pktns->flags & QUIC_FL_PKTNS_ACK_RECEIVED) || (qc->els[QUIC_TLS_ENC_LEVEL_APP].pktns->flags & QUIC_FL_PKTNS_ACK_RECEIVED) || - (qc->state & QUIC_HS_ST_COMPLETE)) + HA_ATOMIC_LOAD(&qc->state) >= QUIC_HS_ST_COMPLETE) return 1; return 0; @@ -608,6 +608,7 @@ static inline void qc_set_timer(struct ssl_sock_ctx *ctx) struct quic_conn *qc; struct quic_pktns *pktns; unsigned int pto; + int handshake_complete; TRACE_ENTER(QUIC_EV_CONN_STIMER, ctx->conn, NULL, NULL, &ctx->conn->qc->path->ifae_pkts); @@ -629,7 +630,8 @@ static inline void qc_set_timer(struct ssl_sock_ctx *ctx) goto out; } - pktns = quic_pto_pktns(qc, qc->state & QUIC_HS_ST_COMPLETE, &pto); + handshake_complete = HA_ATOMIC_LOAD(&qc->state) >= QUIC_HS_ST_COMPLETE; + pktns = quic_pto_pktns(qc, handshake_complete, &pto); if (tick_isset(pto)) qc->timer = pto; out: @@ -1495,7 +1497,7 @@ static inline int qc_provide_cdata(struct quic_enc_level *el, struct quic_rx_packet *pkt, struct quic_rx_crypto_frm *cf) { - int ssl_err; + int ssl_err, state; struct quic_conn *qc; TRACE_ENTER(QUIC_EV_CONN_SSLDATA, ctx->conn); @@ -1511,43 +1513,44 @@ static inline int qc_provide_cdata(struct quic_enc_level *el, TRACE_PROTO("in order CRYPTO data", QUIC_EV_CONN_SSLDATA, ctx->conn,, cf, ctx->ssl); - if (qc->state < QUIC_HS_ST_COMPLETE) { + state = HA_ATOMIC_LOAD(&qc->state); + if (state < QUIC_HS_ST_COMPLETE) { ssl_err = SSL_do_handshake(ctx->ssl); if (ssl_err != 1) { ssl_err = SSL_get_error(ctx->ssl, ssl_err); if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { TRACE_PROTO("SSL handshake", - QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); + QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err); goto out; } TRACE_DEVEL("SSL handshake error", - QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); + QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err); goto err; } - TRACE_PROTO("SSL handshake OK", QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state); + TRACE_PROTO("SSL handshake OK", QUIC_EV_CONN_HDSHK, ctx->conn, &state); if (objt_listener(ctx->conn->target)) - qc->state = QUIC_HS_ST_CONFIRMED; + HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_CONFIRMED); else - qc->state = QUIC_HS_ST_COMPLETE; + HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_COMPLETE); } else { ssl_err = SSL_process_quic_post_handshake(ctx->ssl); if (ssl_err != 1) { ssl_err = SSL_get_error(ctx->ssl, ssl_err); if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { TRACE_DEVEL("SSL post handshake", - QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); + QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err); goto out; } TRACE_DEVEL("SSL post handshake error", - QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); + QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err); goto err; } TRACE_PROTO("SSL post handshake succeeded", - QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state); + QUIC_EV_CONN_HDSHK, ctx->conn, &state); } out: @@ -1938,7 +1941,7 @@ static int qc_parse_pkt_frms(struct quic_rx_packet *pkt, struct ssl_sock_ctx *ct if (objt_listener(ctx->conn->target)) goto err; - conn->state = QUIC_HS_ST_CONFIRMED; + HA_ATOMIC_STORE(&conn->state, QUIC_HS_ST_CONFIRMED); break; default: goto err; @@ -1949,12 +1952,12 @@ static int qc_parse_pkt_frms(struct quic_rx_packet *pkt, struct ssl_sock_ctx *ct * has successfully parse a Handshake packet. The Initial encryption must also * be discarded. */ - if (conn->state == QUIC_HS_ST_SERVER_INITIAL && + if (HA_ATOMIC_LOAD(&conn->state) == QUIC_HS_ST_SERVER_INITIAL && pkt->type == QUIC_PACKET_TYPE_HANDSHAKE) { quic_tls_discard_keys(&conn->els[QUIC_TLS_ENC_LEVEL_INITIAL]); quic_pktns_discard(conn->els[QUIC_TLS_ENC_LEVEL_INITIAL].pktns, conn); qc_set_timer(ctx); - conn->state = QUIC_HS_ST_SERVER_HANDSHAKE; + HA_ATOMIC_STORE(&conn->state, QUIC_HS_ST_SERVER_HANDSHAKE); } TRACE_LEAVE(QUIC_EV_CONN_PRSHPKT, ctx->conn); @@ -2004,7 +2007,7 @@ static int qc_prep_hdshk_pkts(struct qring *qr, struct ssl_sock_ctx *ctx) TRACE_ENTER(QUIC_EV_CONN_PHPKTS, ctx->conn); qc = ctx->conn->qc; - if (!quic_get_tls_enc_levels(&tel, &next_tel, qc->state)) { + if (!quic_get_tls_enc_levels(&tel, &next_tel, HA_ATOMIC_LOAD(&qc->state))) { TRACE_DEVEL("unknown enc. levels", QUIC_EV_CONN_PHPKTS, ctx->conn); goto err; } @@ -2088,12 +2091,12 @@ static int qc_prep_hdshk_pkts(struct qring *qr, struct ssl_sock_ctx *ctx) /* Discard the Initial encryption keys as soon as * a handshake packet could be built. */ - if (qc->state == QUIC_HS_ST_CLIENT_INITIAL && + if (HA_ATOMIC_LOAD(&qc->state) == QUIC_HS_ST_CLIENT_INITIAL && pkt_type == QUIC_PACKET_TYPE_HANDSHAKE) { quic_tls_discard_keys(&qc->els[QUIC_TLS_ENC_LEVEL_INITIAL]); quic_pktns_discard(qc->els[QUIC_TLS_ENC_LEVEL_INITIAL].pktns, qc); qc_set_timer(ctx); - qc->state = QUIC_HS_ST_CLIENT_HANDSHAKE; + HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_CLIENT_HANDSHAKE); } /* Special case for Initial packets: when they have all * been sent, select the next level. @@ -2478,7 +2481,7 @@ static inline void qc_rm_hp_pkts(struct quic_enc_level *el, struct ssl_sock_ctx app_qel = &ctx->conn->qc->els[QUIC_TLS_ENC_LEVEL_APP]; /* A server must not process incoming 1-RTT packets before the handshake is complete. */ if (el == app_qel && objt_listener(ctx->conn->target) && - ctx->conn->qc->state < QUIC_HS_ST_COMPLETE) { + HA_ATOMIC_LOAD(&ctx->conn->qc->state) < QUIC_HS_ST_COMPLETE) { TRACE_PROTO("hp not removed (handshake not completed)", QUIC_EV_CONN_ELRMHP, ctx->conn); goto out; @@ -2622,9 +2625,10 @@ struct task *quic_conn_io_cb(struct task *t, void *context, unsigned int state) ctx = context; qc = ctx->conn->qc; qr = NULL; - TRACE_ENTER(QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state); + st = HA_ATOMIC_LOAD(&qc->state); + TRACE_ENTER(QUIC_EV_CONN_HDSHK, ctx->conn, &st); ssl_err = SSL_ERROR_NONE; - if (!quic_get_tls_enc_levels(&tel, &next_tel, qc->state)) + if (!quic_get_tls_enc_levels(&tel, &next_tel, st)) goto err; qel = &qc->els[tel]; @@ -2674,15 +2678,14 @@ struct task *quic_conn_io_cb(struct task *t, void *context, unsigned int state) goto next_level; } - out: MT_LIST_APPEND(qc->tx.qring_list, &qr->mt_list); - TRACE_LEAVE(QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state); + TRACE_LEAVE(QUIC_EV_CONN_HDSHK, ctx->conn, &st); return t; err: if (qr) MT_LIST_APPEND(qc->tx.qring_list, &qr->mt_list); - TRACE_DEVEL("leaving in error", QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); + TRACE_DEVEL("leaving in error", QUIC_EV_CONN_HDSHK, ctx->conn, &st, &ssl_err); return t; } @@ -2768,7 +2771,7 @@ static struct task *process_timer(struct task *task, void *ctx, unsigned int sta struct ssl_sock_ctx *conn_ctx; struct quic_conn *qc; struct quic_pktns *pktns; - + int st; conn_ctx = task->context; qc = conn_ctx->conn->qc; @@ -2786,11 +2789,12 @@ static struct task *process_timer(struct task *task, void *ctx, unsigned int sta goto out; } + st = HA_ATOMIC_LOAD(&qc->state); if (qc->path->in_flight) { - pktns = quic_pto_pktns(qc, qc->state >= QUIC_HS_ST_COMPLETE, NULL); + pktns = quic_pto_pktns(qc, st >= QUIC_HS_ST_COMPLETE, NULL); pktns->tx.pto_probe = 1; } - else if (objt_server(qc->conn->target) && qc->state <= QUIC_HS_ST_COMPLETE) { + else if (objt_server(qc->conn->target) && st <= QUIC_HS_ST_COMPLETE) { struct quic_enc_level *iel = &qc->els[QUIC_TLS_ENC_LEVEL_INITIAL]; struct quic_enc_level *hel = &qc->els[QUIC_TLS_ENC_LEVEL_HANDSHAKE]; @@ -2837,7 +2841,7 @@ static struct quic_conn *qc_new_conn(unsigned int version, int ipv4, if (server) { struct listener *l = owner; - qc->state = QUIC_HS_ST_SERVER_INITIAL; + HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_SERVER_INITIAL); /* Copy the initial DCID. */ qc->odcid.len = dcid_len; if (qc->odcid.len) @@ -2851,7 +2855,7 @@ static struct quic_conn *qc_new_conn(unsigned int version, int ipv4, } /* QUIC Client (outgoing connection to servers) */ else { - qc->state = QUIC_HS_ST_CLIENT_INITIAL; + HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_CLIENT_INITIAL); if (dcid_len) memcpy(qc->dcid.data, dcid, dcid_len); qc->dcid.len = dcid_len; @@ -3000,7 +3004,8 @@ static int qc_pkt_may_rm_hp(struct quic_rx_packet *pkt, } if (((*qel)->tls_ctx.rx.flags & QUIC_FL_TLS_SECRETS_SET) && - (tel != QUIC_TLS_ENC_LEVEL_APP || ctx->conn->qc->state >= QUIC_HS_ST_COMPLETE)) + (tel != QUIC_TLS_ENC_LEVEL_APP || + HA_ATOMIC_LOAD(&ctx->conn->qc->state) >= QUIC_HS_ST_COMPLETE)) return 1; return 0; @@ -4276,14 +4281,15 @@ static int qc_conn_init(struct connection *conn, void **xprt_ctx) SSL_set_connect_state(ctx->ssl); ssl_err = SSL_do_handshake(ctx->ssl); if (ssl_err != 1) { + int st; + + st = HA_ATOMIC_LOAD(&qc->state); ssl_err = SSL_get_error(ctx->ssl, ssl_err); if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { - TRACE_PROTO("SSL handshake", - QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); + TRACE_PROTO("SSL handshake", QUIC_EV_CONN_HDSHK, ctx->conn, &st, &ssl_err); } else { - TRACE_DEVEL("SSL handshake error", - QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); + TRACE_DEVEL("SSL handshake error", QUIC_EV_CONN_HDSHK, ctx->conn, &st, &ssl_err); goto err; } }