From: Stephan Bosch Date: Fri, 20 Oct 2023 23:27:23 +0000 (+0200) Subject: auth: sasl-server-request - Put sasl_server_request in its own pool and add refcounting X-Git-Tag: 2.4.2~198 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=85e7106a5c812c21e30307904542aeebbca16819;p=thirdparty%2Fdovecot%2Fcore.git auth: sasl-server-request - Put sasl_server_request in its own pool and add refcounting --- diff --git a/src/auth/sasl-server-private.h b/src/auth/sasl-server-private.h index b1c5f6b7e7..f8a5c7b28b 100644 --- a/src/auth/sasl-server-private.h +++ b/src/auth/sasl-server-private.h @@ -11,6 +11,7 @@ enum sasl_server_passdb_type { struct sasl_server_request { pool_t pool; + int refcount; struct sasl_server_instance *sinst; struct sasl_server_req_ctx *rctx; struct sasl_server_mech_request *mech; diff --git a/src/auth/sasl-server-protected.h b/src/auth/sasl-server-protected.h index ccb1fa3607..08d4d90fd8 100644 --- a/src/auth/sasl-server-protected.h +++ b/src/auth/sasl-server-protected.h @@ -102,6 +102,9 @@ void mech_deinit(const struct auth_settings *set); * Request */ +void sasl_server_mech_request_ref(struct sasl_server_mech_request *mreq); +void sasl_server_mech_request_unref(struct sasl_server_mech_request **_mreq); + bool sasl_server_request_set_authid(struct sasl_server_mech_request *mreq, enum sasl_server_authid_type authid_type, const char *authid); diff --git a/src/auth/sasl-server-request.c b/src/auth/sasl-server-request.c index fd732990c8..461cd2c70d 100644 --- a/src/auth/sasl-server-request.c +++ b/src/auth/sasl-server-request.c @@ -26,9 +26,11 @@ void sasl_server_request_create(struct sasl_server_req_ctx *rctx, i_zero(rctx); - pool = request->pool; + pool = pool_alloconly_create( + MEMPOOL_GROWING"sasl_server_request", 2048); req = p_new(pool, struct sasl_server_request, 1); req->pool = pool; + req->refcount = 1; req->sinst = sinst; req->rctx = rctx; @@ -47,7 +49,7 @@ void sasl_server_request_create(struct sasl_server_req_ctx *rctx, mreq->set = &sinst->set; mreq->mech = mech; mreq->mech_event = event_parent; - mreq->protocol = p_strdup(mreq->pool, protocol); + mreq->protocol = p_strdup(pool, protocol); req->mech = mreq; rctx->mech = mech; @@ -55,17 +57,28 @@ void sasl_server_request_create(struct sasl_server_req_ctx *rctx, rctx->request = req; } -void sasl_server_request_destroy(struct sasl_server_req_ctx *rctx) +void sasl_server_mech_request_ref(struct sasl_server_mech_request *mreq) { - struct sasl_server_request *req = rctx->request; + i_assert(mreq->req->refcount > 0); + mreq->req->refcount++; +} - i_zero(rctx); - if (req == NULL) +void sasl_server_mech_request_unref(struct sasl_server_mech_request **_mreq) +{ + struct sasl_server_mech_request *mreq = *_mreq; + + *_mreq = NULL; + if (mreq == NULL) + return; + + struct sasl_server_request *req = mreq->req; + + i_assert(req->refcount > 0); + if (--req->refcount > 0) return; struct sasl_server_instance *sinst = req->sinst; struct sasl_server *server = sinst->server; - struct sasl_server_mech_request *mreq = req->mech; i_assert(sinst->requests > 0); sinst->requests--; @@ -74,6 +87,41 @@ void sasl_server_request_destroy(struct sasl_server_req_ctx *rctx) if (mreq->mech->def->funcs->auth_free != NULL) mreq->mech->def->funcs->auth_free(mreq); + + if (req->rctx != NULL) + i_zero(req->rctx); + pool_unref(&req->pool); +} + +void sasl_server_request_ref(struct sasl_server_req_ctx *rctx) +{ + sasl_server_mech_request_ref(rctx->request->mech); +} + +void sasl_server_request_unref(struct sasl_server_req_ctx *rctx) +{ + struct sasl_server_request *req = rctx->request; + + i_zero(rctx); + if (req == NULL) + return; + + struct sasl_server_mech_request *mreq = req->mech; + + sasl_server_mech_request_unref(&mreq); +} + +void sasl_server_request_destroy(struct sasl_server_req_ctx *rctx) +{ + struct sasl_server_request *req = rctx->request; + + if (req == NULL) { + i_zero(rctx); + return; + } + + req->rctx = NULL; + sasl_server_request_unref(rctx); } static bool @@ -102,8 +150,10 @@ void sasl_server_request_initial(struct sasl_server_req_ctx *rctx, if (sasl_server_request_fail_on_nuls(req, data, data_size)) return; + sasl_server_mech_request_ref(mreq); i_assert(mech->def->funcs->auth_initial != NULL); mech->def->funcs->auth_initial(mreq, data, data_size); + sasl_server_mech_request_unref(&mreq); } void sasl_server_request_input(struct sasl_server_req_ctx *rctx, @@ -116,8 +166,10 @@ void sasl_server_request_input(struct sasl_server_req_ctx *rctx, if (sasl_server_request_fail_on_nuls(req, data, data_size)) return; + sasl_server_mech_request_ref(mreq); i_assert(mech->def->funcs->auth_continue != NULL); mech->def->funcs->auth_continue(mreq, data, data_size); + sasl_server_mech_request_unref(&mreq); } void sasl_server_request_test_set_authid(struct sasl_server_req_ctx *rctx, @@ -142,6 +194,7 @@ bool sasl_server_request_set_authid(struct sasl_server_mech_request *mreq, mreq->authid = p_strdup(req->pool, authid); + i_assert(req->rctx != NULL); i_assert(funcs->request_set_authid != NULL); return funcs->request_set_authid(req->rctx, authid_type, authid); } @@ -153,6 +206,7 @@ bool sasl_server_request_set_authzid(struct sasl_server_mech_request *mreq, struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); i_assert(funcs->request_set_authzid != NULL); return funcs->request_set_authzid(req->rctx, authzid); } @@ -167,6 +221,7 @@ void sasl_server_request_set_realm(struct sasl_server_mech_request *mreq, i_assert(mreq->realm == NULL); mreq->realm = p_strdup(req->pool, realm); + i_assert(req->rctx != NULL); i_assert(funcs->request_set_realm != NULL); funcs->request_set_realm(req->rctx, realm); } @@ -179,6 +234,7 @@ bool sasl_server_request_get_extra_field(struct sasl_server_mech_request *mreq, struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); if (funcs->request_get_extra_field == NULL) { *field_r = NULL; return FALSE; @@ -193,6 +249,7 @@ void sasl_server_request_start_channel_binding( struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); i_assert(funcs->request_start_channel_binding != NULL); funcs->request_start_channel_binding(req->rctx, type); } @@ -204,6 +261,7 @@ int sasl_server_request_accept_channel_binding( struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); i_assert(funcs->request_accept_channel_binding != NULL); return funcs->request_accept_channel_binding(req->rctx, data_r); } @@ -215,6 +273,8 @@ void sasl_server_request_output(struct sasl_server_mech_request *mreq, struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); + const struct sasl_server_output output = { .status = SASL_SERVER_OUTPUT_CONTINUE, .data = data, @@ -231,6 +291,8 @@ void sasl_server_request_success(struct sasl_server_mech_request *mreq, struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); + const struct sasl_server_output output = { .status = SASL_SERVER_OUTPUT_SUCCESS, .data = data, @@ -249,6 +311,8 @@ sasl_server_request_failure_common(struct sasl_server_mech_request *mreq, struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); + const struct sasl_server_output output = { .status = status, .data = data, @@ -297,6 +361,8 @@ void sasl_server_request_verify_plain( struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); + req->passdb_type = SASL_SERVER_PASSDB_TYPE_VERIFY_PLAIN; req->passdb_callback = callback; @@ -324,6 +390,8 @@ void sasl_server_request_lookup_credentials( struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); + req->passdb_type = SASL_SERVER_PASSDB_TYPE_LOOKUP_CREDENTIALS; req->passdb_callback = callback; @@ -351,6 +419,8 @@ void sasl_server_request_set_credentials( struct sasl_server *server = req->sinst->server; const struct sasl_server_request_funcs *funcs = server->funcs; + i_assert(req->rctx != NULL); + req->passdb_type = SASL_SERVER_PASSDB_TYPE_SET_CREDENTIALS; req->passdb_callback = callback; diff --git a/src/auth/sasl-server.h b/src/auth/sasl-server.h index d541deea8f..97ddf25fae 100644 --- a/src/auth/sasl-server.h +++ b/src/auth/sasl-server.h @@ -134,6 +134,8 @@ void sasl_server_request_create(struct sasl_server_req_ctx *rctx, const struct sasl_server_mech *mech, const char *protocol, struct event *event_parent); +void sasl_server_request_ref(struct sasl_server_req_ctx *rctx); +void sasl_server_request_unref(struct sasl_server_req_ctx *rctx); void sasl_server_request_destroy(struct sasl_server_req_ctx *rctx); void sasl_server_request_initial(struct sasl_server_req_ctx *rctx,