]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
fixup! WIP: daemon/worker: weak pointer logic for tasks daemon-refactor
authorOto Šťáva <oto.stava@nic.cz>
Fri, 24 Jun 2022 08:56:18 +0000 (10:56 +0200)
committerOto Šťáva <oto.stava@nic.cz>
Fri, 24 Jun 2022 08:56:18 +0000 (10:56 +0200)
daemon/worker.c

index 8f488f736e7edf929124daa6375c8c4763ce8705..6a0a0d16ac5ec45d5f273215d28245528c4a710b 100644 (file)
@@ -649,6 +649,7 @@ static inline void qr_task_assert_weakptr(struct qr_task *task)
 static int qr_task_on_send_internal(struct qr_task *task, const uv_handle_t *handle, int status)
 {
        qr_task_assert_weakptr(task);
+       qr_task_weakptr_t taskptr = task->weakptr;
        if (task->finished) {
                kr_require(task->leading == false);
                qr_task_complete(task);
@@ -709,7 +710,7 @@ static int qr_task_on_send_internal(struct qr_task *task, const uv_handle_t *han
        }
 
 cleanup:;
-       if (task->finished) {
+       if (worker_task_exists(taskptr) && task->finished) {
                /* Answer has been sent or an error has occurred,
                 * the task is complete, we can free it. */
                qr_task_free(task);
@@ -725,6 +726,31 @@ int qr_task_on_send(qr_task_weakptr_t taskptr, const uv_handle_t *handle, int st
        return qr_task_on_send_internal(task, handle, status);
 }
 
+struct async_write_data {
+       uv_req_t *req;
+       qr_task_weakptr_t taskptr;
+       char buf[];
+};
+
+static struct async_write_data *make_req_async(uv_req_t *req, const uv_buf_t bufs[], unsigned int nbufs)
+{
+       size_t offs[nbufs + 1];
+       offs[0] = 0;
+       for (unsigned int i = 0; i < nbufs; i++)
+               offs[i + 1] = offs[i] + bufs[i].len;
+       struct async_write_data *adata = malloc(sizeof(*adata) + offs[nbufs]);
+       kr_require(adata);
+
+       adata->taskptr = (qr_task_weakptr_t)req->data;
+       req->data = adata;
+       adata->req = req;
+
+       for (unsigned int i = 0; i < nbufs; i++)
+               memcpy(adata->buf + offs[i], bufs[i].base, bufs[i].len);
+
+       return adata;
+}
+
 static void on_send(uv_udp_send_t *req, int status)
 {
        qr_task_weakptr_t taskptr = (qr_task_weakptr_t) req->data;
@@ -733,6 +759,37 @@ static void on_send(uv_udp_send_t *req, int status)
        free(req);
 }
 
+static void on_async_send(uv_udp_send_t *req, int status)
+{
+       struct async_write_data *adata = req->data;
+       req->data = (void *)adata->taskptr;
+       on_send(req, status);
+       free(adata);
+}
+
+static int kr_udp_send(uv_udp_send_t *req, uv_udp_t *handle, const uv_buf_t bufs[],
+                       unsigned int nbufs, const struct sockaddr *addr)
+{
+       size_t total_len = 0;
+       for (unsigned int i = 0; i < nbufs; i++)
+               total_len += bufs[i].len;
+       req->handle = handle;
+       int ret = uv_udp_try_send(handle, bufs, nbufs, addr);
+       if (ret == total_len) {
+               on_send(req, 0);
+               return 0;
+       }
+
+       if (ret >= 0)
+               return UV_EIO;
+       if (ret != UV_EAGAIN)
+               return ret;
+
+       struct async_write_data *adata = make_req_async((uv_req_t *)req, bufs, nbufs);
+       uv_buf_t buf = { .base = adata->buf, .len = total_len };
+       return uv_udp_send(req, handle, &buf, 1, addr, &on_async_send);
+}
+
 static void on_write(uv_write_t *req, int status)
 {
        qr_task_weakptr_t taskptr = (qr_task_weakptr_t) req->data;
@@ -741,6 +798,37 @@ static void on_write(uv_write_t *req, int status)
        free(req);
 }
 
+static void on_async_write(uv_write_t *req, int status)
+{
+       struct async_write_data *adata = req->data;
+       req->data = (void *)adata->taskptr;
+       on_write(req, status);
+       free(adata);
+}
+
+static int kr_write(uv_write_t *req, uv_stream_t *handle, const uv_buf_t bufs[],
+                    unsigned int nbufs)
+{
+       size_t total_len = 0;
+       for (unsigned int i = 0; i < nbufs; i++)
+               total_len += bufs[i].len;
+       req->handle = handle;
+       int ret = uv_try_write(handle, bufs, nbufs);
+       if (ret == total_len) {
+               on_write(req, 0);
+               return 0;
+       }
+
+       if (ret >= 0)
+               return UV_EIO;
+       if (ret != UV_EAGAIN)
+               return ret;
+
+       struct async_write_data *adata = make_req_async((uv_req_t *)req, bufs, nbufs);
+       uv_buf_t buf = { .base = adata->buf, .len = total_len };
+       return uv_write(req, handle, &buf, 1, &on_async_write);
+}
+
 static int qr_task_send(struct qr_task *task, struct session *session,
                        const struct sockaddr *addr, knot_pkt_t *pkt)
 {
@@ -805,7 +893,7 @@ static int qr_task_send(struct qr_task *task, struct session *session,
                uv_udp_send_t *send_req = (uv_udp_send_t *)ioreq;
                uv_buf_t buf = { (char *)pkt->wire, pkt->size };
                send_req->data = (void *)task->weakptr;
-               ret = uv_udp_send(send_req, (uv_udp_t *)handle, &buf, 1, addr, &on_send);
+               ret = kr_udp_send(send_req, (uv_udp_t *)handle, &buf, 1, addr);
        } else if (handle->type == UV_TCP) {
                uv_write_t *write_req = (uv_write_t *)ioreq;
                /* We need to write message length in native byte order,
@@ -829,7 +917,7 @@ static int qr_task_send(struct qr_task *task, struct session *session,
                        { (char *)pkt->wire, pkt->size },
                };
                write_req->data = (void *)task->weakptr;
-               ret = uv_write(write_req, (uv_stream_t *)handle, buf, 3, &on_write);
+               ret = kr_write(write_req, (uv_stream_t *)handle, buf, 3);
        } else {
                kr_assert(false);
        }
@@ -860,6 +948,8 @@ static int qr_task_send(struct qr_task *task, struct session *session,
                        the_worker->stats.err_udp += 1;
                else
                        the_worker->stats.err_tcp += 1;
+
+               qr_task_on_send_internal(task, handle, ret);
        }
 
        /* Update outgoing query statistics */