From: Willy Tarreau Date: Mon, 11 Apr 2022 09:29:11 +0000 (+0200) Subject: MEDIUM: ssl: stop using conn->xprt_ctx to access the ssl_sock_ctx X-Git-Tag: v2.6-dev6~127 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=939b0bf866b349391734dcb29d730bf9057e8d7d;p=thirdparty%2Fhaproxy.git MEDIUM: ssl: stop using conn->xprt_ctx to access the ssl_sock_ctx The SSL functions must not use conn->xprt_ctx anymore but find the context by calling conn_get_ssl_sock_ctx(), which will properly pass through the transport layers to retrieve the desired information. Otherwise when the functions are called on a QUIC connection, they refuse to work for not being called on the proper transport. --- diff --git a/src/ssl_sample.c b/src/ssl_sample.c index a18d66c357..fe2817baee 100644 --- a/src/ssl_sample.c +++ b/src/ssl_sample.c @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -491,15 +492,12 @@ smp_fetch_ssl_fc_has_early(const struct arg *args, struct sample *smp, const cha static int smp_fetch_ssl_fc_has_crt(const struct arg *args, struct sample *smp, const char *kw, void *private) { - struct connection *conn; - struct ssl_sock_ctx *ctx; + struct connection *conn = objt_conn(smp->sess->origin); + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - conn = objt_conn(smp->sess->origin); - if (!conn || conn->xprt != &ssl_sock) + if (!ctx) return 0; - ctx = conn->xprt_ctx; - if (conn->flags & CO_FL_WAIT_XPRT) { smp->flags |= SMP_F_MAY_CHANGE; return 0; @@ -1177,7 +1175,7 @@ smp_fetch_ssl_fc(const struct arg *args, struct sample *smp, const char *kw, voi smp->strm ? cs_conn(smp->strm->csb) : NULL; smp->data.type = SMP_T_BOOL; - smp->data.u.sint = (conn && conn->xprt == &ssl_sock); + smp->data.u.sint = conn_is_ssl(conn); return 1; } @@ -1657,9 +1655,9 @@ smp_fetch_ssl_fc_err(const struct arg *args, struct sample *smp, const char *kw, conn = (kw[4] != 'b') ? objt_conn(smp->sess->origin) : smp->strm ? cs_conn(smp->strm->csb) : NULL; - if (!conn || conn->xprt != &ssl_sock) + ctx = conn_get_ssl_sock_ctx(conn); + if (!ctx) return 0; - ctx = conn->xprt_ctx; if (conn->flags & CO_FL_WAIT_XPRT && !conn->err_code) { smp->flags = SMP_F_MAY_CHANGE; @@ -1710,9 +1708,9 @@ smp_fetch_ssl_fc_err_str(const struct arg *args, struct sample *smp, const char conn = (kw[4] != 'b') ? objt_conn(smp->sess->origin) : smp->strm ? cs_conn(smp->strm->csb) : NULL; - if (!conn || conn->xprt != &ssl_sock) + ctx = conn_get_ssl_sock_ctx(conn); + if (!ctx) return 0; - ctx = conn->xprt_ctx; if (conn->flags & CO_FL_WAIT_XPRT && !conn->err_code) { smp->flags = SMP_F_MAY_CHANGE; @@ -1976,15 +1974,10 @@ smp_fetch_ssl_fc_unique_id(const struct arg *args, struct sample *smp, const cha static int smp_fetch_ssl_c_ca_err(const struct arg *args, struct sample *smp, const char *kw, void *private) { - struct connection *conn; - struct ssl_sock_ctx *ctx; - - conn = objt_conn(smp->sess->origin); - if (!conn || conn->xprt != &ssl_sock) - return 0; - ctx = conn->xprt_ctx; + struct connection *conn = objt_conn(smp->sess->origin); + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - if (conn->flags & CO_FL_WAIT_XPRT && !conn->err_code) { + if (conn && conn->flags & CO_FL_WAIT_XPRT && !conn->err_code) { smp->flags = SMP_F_MAY_CHANGE; return 0; } @@ -2003,18 +1996,13 @@ smp_fetch_ssl_c_ca_err(const struct arg *args, struct sample *smp, const char *k static int smp_fetch_ssl_c_ca_err_depth(const struct arg *args, struct sample *smp, const char *kw, void *private) { - struct connection *conn; - struct ssl_sock_ctx *ctx; - - conn = objt_conn(smp->sess->origin); - if (!conn || conn->xprt != &ssl_sock) - return 0; + struct connection *conn = objt_conn(smp->sess->origin); + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - if (conn->flags & CO_FL_WAIT_XPRT && !conn->err_code) { + if (conn && conn->flags & CO_FL_WAIT_XPRT && !conn->err_code) { smp->flags = SMP_F_MAY_CHANGE; return 0; } - ctx = conn->xprt_ctx; if (!ctx) return 0; @@ -2030,20 +2018,14 @@ smp_fetch_ssl_c_ca_err_depth(const struct arg *args, struct sample *smp, const c static int smp_fetch_ssl_c_err(const struct arg *args, struct sample *smp, const char *kw, void *private) { - struct connection *conn; - struct ssl_sock_ctx *ctx; - - conn = objt_conn(smp->sess->origin); - if (!conn || conn->xprt != &ssl_sock) - return 0; + struct connection *conn = objt_conn(smp->sess->origin); + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - if (conn->flags & CO_FL_WAIT_XPRT && !conn->err_code) { + if (conn && conn->flags & CO_FL_WAIT_XPRT && !conn->err_code) { smp->flags = SMP_F_MAY_CHANGE; return 0; } - ctx = conn->xprt_ctx; - if (!ctx) return 0; diff --git a/src/ssl_sock.c b/src/ssl_sock.c index 232d1ce728..26940dbb95 100644 --- a/src/ssl_sock.c +++ b/src/ssl_sock.c @@ -605,10 +605,9 @@ static struct ssl_sock_ctx *ssl_sock_get_ctx(struct connection *conn) SSL *ssl_sock_get_ssl_object(struct connection *conn) { - if (!conn_is_ssl(conn)) - return NULL; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - return ((struct ssl_sock_ctx *)(conn->xprt_ctx))->ssl; + return ctx ? ctx->ssl : NULL; } /* * This function gives the detail of the SSL error. It is used only @@ -1576,7 +1575,7 @@ void ssl_sock_infocbk(const SSL *ssl, int where, int ret) (void)ret; /* shut gcc stupid warning */ if (conn) - ctx = conn->xprt_ctx; + ctx = conn_get_ssl_sock_ctx(conn); #ifdef USE_QUIC else if (qc) ctx = qc->xprt_ctx; @@ -1633,7 +1632,8 @@ int ssl_sock_bind_verifycbk(int ok, X509_STORE_CTX *x_store) conn = SSL_get_ex_data(ssl, ssl_app_data_index); client_crt = SSL_get_ex_data(ssl, ssl_client_crt_ref_index); - ctx = conn->xprt_ctx; + ctx = conn_get_ssl_sock_ctx(conn); + ALREADY_CHECKED(ctx); ctx->xprt_st |= SSL_SOCK_ST_FL_VERIFY_DONE; @@ -1709,7 +1709,7 @@ static void ssl_sock_parse_heartbeat(struct connection *conn, int write_p, int v /* test heartbeat received (write_p is set to 0 for a received record) */ if ((content_type == TLS1_RT_HEARTBEAT) && (write_p == 0)) { - struct ssl_sock_ctx *ctx = conn->xprt_ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); const unsigned char *p = buf; unsigned int payload; @@ -4978,7 +4978,8 @@ static int ssl_sock_srv_verifycbk(int ok, X509_STORE_CTX *ctx) ssl = X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()); conn = SSL_get_ex_data(ssl, ssl_app_data_index); - ssl_ctx = conn->xprt_ctx; + ssl_ctx = conn_get_ssl_sock_ctx(conn); + ALREADY_CHECKED(ssl_ctx); /* We're checking if the provided hostnames match the desired one. The * desired hostname comes from the SNI we presented if any, or if not @@ -5812,7 +5813,7 @@ err: */ static int ssl_sock_handshake(struct connection *conn, unsigned int flag) { - struct ssl_sock_ctx *ctx = conn->xprt_ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); int ret; struct ssl_counters *counters = NULL; struct ssl_counters *counters_px = NULL; @@ -5845,7 +5846,7 @@ static int ssl_sock_handshake(struct connection *conn, unsigned int flag) break; } - if (!conn->xprt_ctx) + if (!ctx) goto out_error; /* don't start calculating a handshake on a dead connection */ @@ -6477,8 +6478,8 @@ static size_t ssl_sock_to_buf(struct connection *conn, void *xprt_ctx, struct bu } else if (ret == SSL_ERROR_ZERO_RETURN) goto read0; else if (ret == SSL_ERROR_SSL) { - struct ssl_sock_ctx *ctx = conn->xprt_ctx; - if (!ctx->error_code) + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); + if (ctx && !ctx->error_code) ctx->error_code = ERR_peek_error(); conn->err_code = CO_ERR_SSL_FATAL; } @@ -6644,7 +6645,9 @@ static size_t ssl_sock_from_buf(struct connection *conn, void *xprt_ctx, const s break; } else if (ret == SSL_ERROR_SSL || ret == SSL_ERROR_SYSCALL) { - struct ssl_sock_ctx *ctx = conn->xprt_ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); + + ALREADY_CHECKED(ctx); if (!ctx->error_code) ctx->error_code = ERR_peek_error(); conn->err_code = CO_ERR_SSL_FATAL; @@ -6759,14 +6762,11 @@ static void ssl_sock_shutw(struct connection *conn, void *xprt_ctx, int clean) /* used for ppv2 pkey algo (can be used for logging) */ int ssl_sock_get_pkey_algo(struct connection *conn, struct buffer *out) { - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); X509 *crt; - if (!conn_is_ssl(conn)) + if (!ctx) return 0; - - ctx = conn->xprt_ctx; - crt = SSL_get_certificate(ctx->ssl); if (!crt) return 0; @@ -6777,14 +6777,13 @@ int ssl_sock_get_pkey_algo(struct connection *conn, struct buffer *out) /* used for ppv2 cert signature (can be used for logging) */ const char *ssl_sock_get_cert_sig(struct connection *conn) { - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); __OPENSSL_110_CONST__ ASN1_OBJECT *algorithm; X509 *crt; - if (!conn_is_ssl(conn)) + if (!ctx) return NULL; - ctx = conn->xprt_ctx; crt = SSL_get_certificate(ctx->ssl); if (!crt) return NULL; @@ -6796,11 +6795,10 @@ const char *ssl_sock_get_cert_sig(struct connection *conn) const char *ssl_sock_get_sni(struct connection *conn) { #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - if (!conn_is_ssl(conn)) + if (!ctx) return NULL; - ctx = conn->xprt_ctx; return SSL_get_servername(ctx->ssl, TLSEXT_NAMETYPE_host_name); #else return NULL; @@ -6810,33 +6808,30 @@ const char *ssl_sock_get_sni(struct connection *conn) /* used for logging/ppv2, may be changed for a sample fetch later */ const char *ssl_sock_get_cipher_name(struct connection *conn) { - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - if (!conn_is_ssl(conn)) + if (!ctx) return NULL; - ctx = conn->xprt_ctx; return SSL_get_cipher_name(ctx->ssl); } /* used for logging/ppv2, may be changed for a sample fetch later */ const char *ssl_sock_get_proto_version(struct connection *conn) { - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - if (!conn_is_ssl(conn)) + if (!ctx) return NULL; - ctx = conn->xprt_ctx; return SSL_get_version(ctx->ssl); } void ssl_sock_set_alpn(struct connection *conn, const unsigned char *alpn, int len) { #ifdef TLSEXT_TYPE_application_layer_protocol_negotiation - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - if (!conn_is_ssl(conn)) + if (!ctx) return; - ctx = conn->xprt_ctx; SSL_set_alpn_protos(ctx->ssl, alpn, len); #endif } @@ -6847,17 +6842,16 @@ void ssl_sock_set_alpn(struct connection *conn, const unsigned char *alpn, int l void ssl_sock_set_servername(struct connection *conn, const char *hostname) { #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); struct server *s; char *prev_name; - if (!conn_is_ssl(conn)) + if (!ctx) return; BUG_ON(!(conn->flags & CO_FL_WAIT_L6_CONN)); BUG_ON(!(conn->flags & CO_FL_SSL_WAIT_HS)); - ctx = conn->xprt_ctx; s = __objt_server(conn->target); /* if the SNI changes, we must destroy the reusable context so that a @@ -6886,7 +6880,7 @@ void ssl_sock_set_servername(struct connection *conn, const char *hostname) int ssl_sock_get_remote_common_name(struct connection *conn, struct buffer *dest) { - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); X509 *crt = NULL; X509_NAME *name; const char find_cn[] = "CN"; @@ -6896,9 +6890,8 @@ int ssl_sock_get_remote_common_name(struct connection *conn, }; int result = -1; - if (!conn_is_ssl(conn)) + if (!ctx) goto out; - ctx = conn->xprt_ctx; /* SSL_get_peer_certificate, it increase X509 * ref count */ crt = SSL_get_peer_certificate(ctx->ssl); @@ -6920,12 +6913,11 @@ out: /* returns 1 if client passed a certificate for this session, 0 if not */ int ssl_sock_get_cert_used_sess(struct connection *conn) { - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); X509 *crt = NULL; - if (!conn_is_ssl(conn)) + if (!ctx) return 0; - ctx = conn->xprt_ctx; /* SSL_get_peer_certificate, it increase X509 * ref count */ crt = SSL_get_peer_certificate(ctx->ssl); @@ -6939,22 +6931,20 @@ int ssl_sock_get_cert_used_sess(struct connection *conn) /* returns 1 if client passed a certificate for this connection, 0 if not */ int ssl_sock_get_cert_used_conn(struct connection *conn) { - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - if (!conn_is_ssl(conn)) + if (!ctx) return 0; - ctx = conn->xprt_ctx; return SSL_SOCK_ST_FL_VERIFY_DONE & ctx->xprt_st ? 1 : 0; } /* returns result from SSL verify */ unsigned int ssl_sock_get_verify_result(struct connection *conn) { - struct ssl_sock_ctx *ctx; + struct ssl_sock_ctx *ctx = conn_get_ssl_sock_ctx(conn); - if (!conn_is_ssl(conn)) + if (!ctx) return (unsigned int)X509_V_ERR_APPLICATION_VERIFICATION; - ctx = conn->xprt_ctx; return (unsigned int)SSL_get_verify_result(ctx->ssl); }