]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
auth: sasl-server-request - Put sasl_server_request in its own pool and add refcounting
authorStephan Bosch <stephan.bosch@open-xchange.com>
Fri, 20 Oct 2023 23:27:23 +0000 (01:27 +0200)
committertimo.sirainen <timo.sirainen@open-xchange.com>
Thu, 9 Oct 2025 08:41:22 +0000 (08:41 +0000)
src/auth/sasl-server-private.h
src/auth/sasl-server-protected.h
src/auth/sasl-server-request.c
src/auth/sasl-server.h

index b1c5f6b7e79867954a3f1a3e34e3b54f8e0e7133..f8a5c7b28b58176d92086ef4bb3c8e8816da23ad 100644 (file)
@@ -11,6 +11,7 @@ enum sasl_server_passdb_type {
 
 struct sasl_server_request {
        pool_t pool;
+       int refcount;
        struct sasl_server_instance *sinst;
        struct sasl_server_req_ctx *rctx;
        struct sasl_server_mech_request *mech;
index ccb1fa3607914ca5d73c92ace6963a4c46ef670e..08d4d90fd8ce29b5582c3a5d88b32b86043d5f5d 100644 (file)
@@ -102,6 +102,9 @@ void mech_deinit(const struct auth_settings *set);
  * Request
  */
 
+void sasl_server_mech_request_ref(struct sasl_server_mech_request *mreq);
+void sasl_server_mech_request_unref(struct sasl_server_mech_request **_mreq);
+
 bool sasl_server_request_set_authid(struct sasl_server_mech_request *mreq,
                                    enum sasl_server_authid_type authid_type,
                                    const char *authid);
index fd732990c8c0febff92181ae2ffa33147f4f6a28..461cd2c70d453a9498b48d28a01d34f61a101266 100644 (file)
@@ -26,9 +26,11 @@ void sasl_server_request_create(struct sasl_server_req_ctx *rctx,
 
        i_zero(rctx);
 
-       pool = request->pool;
+       pool = pool_alloconly_create(
+               MEMPOOL_GROWING"sasl_server_request", 2048);
        req = p_new(pool, struct sasl_server_request, 1);
        req->pool = pool;
+       req->refcount = 1;
        req->sinst = sinst;
        req->rctx = rctx;
 
@@ -47,7 +49,7 @@ void sasl_server_request_create(struct sasl_server_req_ctx *rctx,
        mreq->set = &sinst->set;
        mreq->mech = mech;
        mreq->mech_event = event_parent;
-       mreq->protocol = p_strdup(mreq->pool, protocol);
+       mreq->protocol = p_strdup(pool, protocol);
 
        req->mech = mreq;
        rctx->mech = mech;
@@ -55,17 +57,28 @@ void sasl_server_request_create(struct sasl_server_req_ctx *rctx,
        rctx->request = req;
 }
 
-void sasl_server_request_destroy(struct sasl_server_req_ctx *rctx)
+void sasl_server_mech_request_ref(struct sasl_server_mech_request *mreq)
 {
-       struct sasl_server_request *req = rctx->request;
+       i_assert(mreq->req->refcount > 0);
+       mreq->req->refcount++;
+}
 
-       i_zero(rctx);
-       if (req == NULL)
+void sasl_server_mech_request_unref(struct sasl_server_mech_request **_mreq)
+{
+       struct sasl_server_mech_request *mreq = *_mreq;
+
+       *_mreq = NULL;
+       if (mreq == NULL)
+               return;
+
+       struct sasl_server_request *req = mreq->req;
+
+       i_assert(req->refcount > 0);
+       if (--req->refcount > 0)
                return;
 
        struct sasl_server_instance *sinst = req->sinst;
        struct sasl_server *server = sinst->server;
-       struct sasl_server_mech_request *mreq = req->mech;
 
        i_assert(sinst->requests > 0);
        sinst->requests--;
@@ -74,6 +87,41 @@ void sasl_server_request_destroy(struct sasl_server_req_ctx *rctx)
 
        if (mreq->mech->def->funcs->auth_free != NULL)
                mreq->mech->def->funcs->auth_free(mreq);
+
+       if (req->rctx != NULL)
+               i_zero(req->rctx);
+       pool_unref(&req->pool);
+}
+
+void sasl_server_request_ref(struct sasl_server_req_ctx *rctx)
+{
+       sasl_server_mech_request_ref(rctx->request->mech);
+}
+
+void sasl_server_request_unref(struct sasl_server_req_ctx *rctx)
+{
+       struct sasl_server_request *req = rctx->request;
+
+       i_zero(rctx);
+       if (req == NULL)
+               return;
+
+       struct sasl_server_mech_request *mreq = req->mech;
+
+       sasl_server_mech_request_unref(&mreq);
+}
+
+void sasl_server_request_destroy(struct sasl_server_req_ctx *rctx)
+{
+       struct sasl_server_request *req = rctx->request;
+
+       if (req == NULL) {
+               i_zero(rctx);
+               return;
+       }
+
+       req->rctx = NULL;
+       sasl_server_request_unref(rctx);
 }
 
 static bool
@@ -102,8 +150,10 @@ void sasl_server_request_initial(struct sasl_server_req_ctx *rctx,
        if (sasl_server_request_fail_on_nuls(req, data, data_size))
                return;
 
+       sasl_server_mech_request_ref(mreq);
        i_assert(mech->def->funcs->auth_initial != NULL);
        mech->def->funcs->auth_initial(mreq, data, data_size);
+       sasl_server_mech_request_unref(&mreq);
 }
 
 void sasl_server_request_input(struct sasl_server_req_ctx *rctx,
@@ -116,8 +166,10 @@ void sasl_server_request_input(struct sasl_server_req_ctx *rctx,
        if (sasl_server_request_fail_on_nuls(req, data, data_size))
                return;
 
+       sasl_server_mech_request_ref(mreq);
        i_assert(mech->def->funcs->auth_continue != NULL);
        mech->def->funcs->auth_continue(mreq, data, data_size);
+       sasl_server_mech_request_unref(&mreq);
 }
 
 void sasl_server_request_test_set_authid(struct sasl_server_req_ctx *rctx,
@@ -142,6 +194,7 @@ bool sasl_server_request_set_authid(struct sasl_server_mech_request *mreq,
 
        mreq->authid = p_strdup(req->pool, authid);
 
+       i_assert(req->rctx != NULL);
        i_assert(funcs->request_set_authid != NULL);
        return funcs->request_set_authid(req->rctx, authid_type, authid);
 }
@@ -153,6 +206,7 @@ bool sasl_server_request_set_authzid(struct sasl_server_mech_request *mreq,
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
        i_assert(funcs->request_set_authzid != NULL);
        return funcs->request_set_authzid(req->rctx, authzid);
 }
@@ -167,6 +221,7 @@ void sasl_server_request_set_realm(struct sasl_server_mech_request *mreq,
        i_assert(mreq->realm == NULL);
        mreq->realm = p_strdup(req->pool, realm);
 
+       i_assert(req->rctx != NULL);
        i_assert(funcs->request_set_realm != NULL);
        funcs->request_set_realm(req->rctx, realm);
 }
@@ -179,6 +234,7 @@ bool sasl_server_request_get_extra_field(struct sasl_server_mech_request *mreq,
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
        if (funcs->request_get_extra_field == NULL) {
                *field_r = NULL;
                return FALSE;
@@ -193,6 +249,7 @@ void sasl_server_request_start_channel_binding(
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
        i_assert(funcs->request_start_channel_binding != NULL);
        funcs->request_start_channel_binding(req->rctx, type);
 }
@@ -204,6 +261,7 @@ int sasl_server_request_accept_channel_binding(
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
        i_assert(funcs->request_accept_channel_binding != NULL);
        return funcs->request_accept_channel_binding(req->rctx, data_r);
 }
@@ -215,6 +273,8 @@ void sasl_server_request_output(struct sasl_server_mech_request *mreq,
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
+
        const struct sasl_server_output output = {
                .status = SASL_SERVER_OUTPUT_CONTINUE,
                .data = data,
@@ -231,6 +291,8 @@ void sasl_server_request_success(struct sasl_server_mech_request *mreq,
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
+
        const struct sasl_server_output output = {
                .status = SASL_SERVER_OUTPUT_SUCCESS,
                .data = data,
@@ -249,6 +311,8 @@ sasl_server_request_failure_common(struct sasl_server_mech_request *mreq,
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
+
        const struct sasl_server_output output = {
                .status = status,
                .data = data,
@@ -297,6 +361,8 @@ void sasl_server_request_verify_plain(
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
+
        req->passdb_type = SASL_SERVER_PASSDB_TYPE_VERIFY_PLAIN;
        req->passdb_callback = callback;
 
@@ -324,6 +390,8 @@ void sasl_server_request_lookup_credentials(
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
+
        req->passdb_type = SASL_SERVER_PASSDB_TYPE_LOOKUP_CREDENTIALS;
        req->passdb_callback = callback;
 
@@ -351,6 +419,8 @@ void sasl_server_request_set_credentials(
        struct sasl_server *server = req->sinst->server;
        const struct sasl_server_request_funcs *funcs = server->funcs;
 
+       i_assert(req->rctx != NULL);
+
        req->passdb_type = SASL_SERVER_PASSDB_TYPE_SET_CREDENTIALS;
        req->passdb_callback = callback;
 
index d541deea8f88dd0d54222b7d5d7e2987c6e20ef3..97ddf25fae4e41cf7b9ba5873ce8a1c410ebce15 100644 (file)
@@ -134,6 +134,8 @@ void sasl_server_request_create(struct sasl_server_req_ctx *rctx,
                                const struct sasl_server_mech *mech,
                                const char *protocol,
                                struct event *event_parent);
+void sasl_server_request_ref(struct sasl_server_req_ctx *rctx);
+void sasl_server_request_unref(struct sasl_server_req_ctx *rctx);
 void sasl_server_request_destroy(struct sasl_server_req_ctx *rctx);
 
 void sasl_server_request_initial(struct sasl_server_req_ctx *rctx,