]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
varlink: Add new varlink_collect method
authorArthur Shau <arthurshau@meta.com>
Wed, 20 Sep 2023 01:52:39 +0000 (18:52 -0700)
committerDaan De Meyer <daan.j.demeyer@gmail.com>
Thu, 19 Oct 2023 09:52:09 +0000 (11:52 +0200)
varlink_collect is meant to be used when the client is willing to wait for the reply from the varlink method, much like varlink_call.
However, unlike varlink_call, it allows the client to collect all "more" replies that may be sent by method before the "final" reply is enqueued.
It aggregates all of these replies into a json variant array that it returns to the client.

src/shared/varlink.c
src/shared/varlink.h
src/test/test-varlink.c

index ae6e142b650ce4a01830a5166f324fd3b810d82c..683ceb314a7531bf26fe9d208093e3ff2486bd95 100644 (file)
@@ -240,6 +240,12 @@ struct VarlinkServer {
         bool exit_on_idle;
 };
 
+typedef struct VarlinkCollectContext {
+        JsonVariant *parameters;
+        const char *error_id;
+        VarlinkReplyFlags flags;
+} VarlinkCollectContext ;
+
 static const char* const varlink_state_table[_VARLINK_STATE_MAX] = {
         [VARLINK_IDLE_CLIENT]              = "idle-client",
         [VARLINK_AWAITING_REPLY]           = "awaiting-reply",
@@ -2088,6 +2094,144 @@ int varlink_callb(
         return varlink_call(v, method, parameters, ret_parameters, ret_error_id, ret_flags);
 }
 
+static void varlink_collect_context_free(VarlinkCollectContext *cc) {
+        assert(cc);
+
+        json_variant_unref(cc->parameters);
+        free((char *)cc->error_id);
+}
+
+static int collect_callback(
+                Varlink *v,
+                JsonVariant *parameters,
+                const char *error_id,
+                VarlinkReplyFlags flags,
+                void *userdata) {
+
+        VarlinkCollectContext *context = ASSERT_PTR(userdata);
+        int r;
+
+        assert(v);
+
+        context->flags = flags;
+        /* If we hit an error, we will drop all collected replies and just return the error_id and flags in varlink_collect() */
+        if (error_id) {
+                context->error_id = error_id;
+                return 0;
+        }
+
+        r = json_variant_append_array(&context->parameters, parameters);
+        if (r < 0)
+                return varlink_log_errno(v, r, "Failed to append JSON object to array: %m");
+
+        return 1;
+}
+
+int varlink_collect(
+                Varlink *v,
+                const char *method,
+                JsonVariant *parameters,
+                JsonVariant **ret_parameters,
+                const char **ret_error_id,
+                VarlinkReplyFlags *ret_flags) {
+
+        _cleanup_(varlink_collect_context_free) VarlinkCollectContext context = {};
+        int r;
+
+        assert_return(v, -EINVAL);
+        assert_return(method, -EINVAL);
+
+        if (v->state == VARLINK_DISCONNECTED)
+                return varlink_log_errno(v, SYNTHETIC_ERRNO(ENOTCONN), "Not connected.");
+        if (v->state != VARLINK_IDLE_CLIENT)
+                return varlink_log_errno(v, SYNTHETIC_ERRNO(EBUSY), "Connection busy.");
+
+        assert(v->n_pending == 0); /* n_pending can't be > 0 if we are in VARLINK_IDLE_CLIENT state */
+
+        /* If there was still a reply pinned from a previous call, now it's the time to get rid of it, so
+         * that we can assign a new reply shortly. */
+        varlink_clear_current(v);
+
+        r = varlink_bind_reply(v, collect_callback);
+        if (r < 0)
+                return varlink_log_errno(v, r, "Failed to bind collect callback");
+
+        varlink_set_userdata(v, &context);
+        r = varlink_observe(v, method, parameters);
+        if (r < 0)
+                return varlink_log_errno(v, r, "Failed to collect varlink method: %m");
+
+        while (v->state == VARLINK_AWAITING_REPLY_MORE) {
+
+                r = varlink_process(v);
+                if (r < 0)
+                        return r;
+
+                /* If we get an error from any of the replies, return immediately with just the error_id and flags*/
+                if (context.error_id) {
+                        if (ret_error_id)
+                                *ret_error_id = TAKE_PTR(context.error_id);
+                        if (ret_flags)
+                                *ret_flags = context.flags;
+                        return 0;
+                }
+
+                if (r > 0)
+                        continue;
+
+                r = varlink_wait(v, USEC_INFINITY);
+                if (r < 0)
+                        return r;
+        }
+
+        switch (v->state) {
+
+        case VARLINK_IDLE_CLIENT:
+                break;
+
+        case VARLINK_PENDING_DISCONNECT:
+        case VARLINK_DISCONNECTED:
+                return varlink_log_errno(v, SYNTHETIC_ERRNO(ECONNRESET), "Connection was closed.");
+
+        case VARLINK_PENDING_TIMEOUT:
+                return varlink_log_errno(v, SYNTHETIC_ERRNO(ETIME), "Connection timed out.");
+
+        default:
+                assert_not_reached();
+        }
+
+        if (ret_parameters)
+                *ret_parameters = TAKE_PTR(context.parameters);
+        if (ret_error_id)
+                *ret_error_id = TAKE_PTR(context.error_id);
+        if (ret_flags)
+                *ret_flags = context.flags;
+        return 1;
+}
+
+int varlink_collectb(
+                Varlink *v,
+                const char *method,
+                JsonVariant **ret_parameters,
+                const char **ret_error_id,
+                VarlinkReplyFlags *ret_flags, ...) {
+
+        _cleanup_(json_variant_unrefp) JsonVariant *parameters = NULL;
+        va_list ap;
+        int r;
+
+        assert_return(v, -EINVAL);
+
+        va_start(ap, ret_flags);
+        r = json_buildv(&parameters, ap);
+        va_end(ap);
+
+        if (r < 0)
+                return varlink_log_errno(v, r, "Failed to build json message: %m");
+
+        return varlink_collect(v, method, parameters, ret_parameters, ret_error_id, ret_flags);
+}
+
 int varlink_reply(Varlink *v, JsonVariant *parameters) {
         _cleanup_(json_variant_unrefp) JsonVariant *m = NULL;
         int r;
index e5541d3ce5f794173650935d6a496ef526e43633..516d3b5a906f5812d088f668f4507b1ad0a9a077 100644 (file)
@@ -89,6 +89,10 @@ int varlink_sendb(Varlink *v, const char *method, ...);
 int varlink_call(Varlink *v, const char *method, JsonVariant *parameters, JsonVariant **ret_parameters, const char **ret_error_id, VarlinkReplyFlags *ret_flags);
 int varlink_callb(Varlink *v, const char *method, JsonVariant **ret_parameters, const char **ret_error_id, VarlinkReplyFlags *ret_flags, ...);
 
+/* Send method call and begin collecting all 'more' replies into an array, finishing when a final reply is sent */
+int varlink_collect(Varlink *v, const char *method, JsonVariant *parameters, JsonVariant **ret_parameters, const char **ret_error_id, VarlinkReplyFlags *ret_flags);
+int varlink_collectb(Varlink *v, const char *method, JsonVariant **ret_parameters, const char **ret_error_id, VarlinkReplyFlags *ret_flags, ...);
+
 /* Enqueue method call, expect a reply, which is eventually delivered to the reply callback */
 int varlink_invoke(Varlink *v, const char *method, JsonVariant *parameters);
 int varlink_invokeb(Varlink *v, const char *method, ...);
index 631305e8f0f0784bfdb5458f83e8c0c82e02f780..8359ab18921eb54eb57e6532eed5dee49192d892 100644 (file)
@@ -49,6 +49,43 @@ static int method_something(Varlink *link, JsonVariant *parameters, VarlinkMetho
         return varlink_reply(link, ret);
 }
 
+static int method_something_more(Varlink *link, JsonVariant *parameters, VarlinkMethodFlags flags, void *userdata) {
+        _cleanup_(json_variant_unrefp) JsonVariant *ret = NULL;
+        int r;
+
+        struct Something {
+                int x;
+                int y;
+        };
+
+        static const JsonDispatch dispatch_table[] = {
+                { "a",  JSON_VARIANT_INTEGER, json_dispatch_int, offsetof(struct Something, x),  JSON_MANDATORY },
+                { "b", JSON_VARIANT_INTEGER, json_dispatch_int, offsetof(struct Something, y), JSON_MANDATORY},
+                {}
+        };
+        struct Something s = {};
+
+        r = json_dispatch(parameters, dispatch_table, NULL, 0, &s);
+
+        for (int i = 0; i < 5; i++) {
+                _cleanup_(json_variant_unrefp) JsonVariant *w = NULL;
+
+                r = json_build(&w, JSON_BUILD_OBJECT(JSON_BUILD_PAIR("sum", JSON_BUILD_INTEGER(s.x + (s.y * i)))));
+                if (r < 0)
+                        return r;
+
+                r = varlink_notify(link, w);
+                if (r < 0)
+                        return r;
+        }
+
+        r = json_build(&ret, JSON_BUILD_OBJECT(JSON_BUILD_PAIR("sum", JSON_BUILD_INTEGER(s.x + (s.y * 5)))));
+        if (r < 0)
+                return r;
+
+        return varlink_reply(link, ret);
+}
+
 static void test_fd(int fd, const void *buf, size_t n) {
         char rbuf[n + 1];
         ssize_t m;
@@ -198,9 +235,12 @@ static void flood_test(const char *address) {
 
 static void *thread(void *arg) {
         _cleanup_(varlink_flush_close_unrefp) Varlink *c = NULL;
-        _cleanup_(json_variant_unrefp) JsonVariant *i = NULL;
-        JsonVariant *o = NULL;
+        _cleanup_(json_variant_unrefp) JsonVariant *i = NULL, *j = NULL;
+        JsonVariant *o = NULL, *k = NULL;
+        const char *error_id;
+        VarlinkReplyFlags flags = 0;
         const char *e;
+        int x = 0;
 
         assert_se(json_build(&i, JSON_BUILD_OBJECT(JSON_BUILD_PAIR("a", JSON_BUILD_INTEGER(88)),
                                                    JSON_BUILD_PAIR("b", JSON_BUILD_INTEGER(99)))) >= 0);
@@ -210,6 +250,18 @@ static void *thread(void *arg) {
         assert_se(varlink_set_allow_fd_passing_input(c, true) >= 0);
         assert_se(varlink_set_allow_fd_passing_output(c, true) >= 0);
 
+        assert_se(varlink_collect(c, "io.test.DoSomethingMore", i, &j, &error_id, &flags) >= 0);
+
+        assert_se(!error_id);
+        assert_se(!flags);
+        assert_se(json_variant_is_array(j) && !json_variant_is_blank_array(j));
+
+        JSON_VARIANT_ARRAY_FOREACH(k, j) {
+                assert_se(json_variant_integer(json_variant_by_key(k, "sum")) == 88 + (99 * x));
+                x++;
+        }
+        assert_se(x == 6);
+
         assert_se(varlink_call(c, "io.test.DoSomething", i, &o, &e, NULL) >= 0);
         assert_se(json_variant_integer(json_variant_by_key(o, "sum")) == 88 + 99);
         assert_se(!e);
@@ -294,19 +346,20 @@ int main(int argc, char *argv[]) {
 
         assert_se(varlink_server_bind_method(s, "io.test.PassFD", method_passfd) >= 0);
         assert_se(varlink_server_bind_method(s, "io.test.DoSomething", method_something) >= 0);
+        assert_se(varlink_server_bind_method(s, "io.test.DoSomethingMore", method_something_more) >= 0);
         assert_se(varlink_server_bind_method(s, "io.test.Done", method_done) >= 0);
         assert_se(varlink_server_bind_connect(s, on_connect) >= 0);
         assert_se(varlink_server_listen_address(s, sp, 0600) >= 0);
         assert_se(varlink_server_attach_event(s, e, 0) >= 0);
         assert_se(varlink_server_set_connections_max(s, OVERLOAD_CONNECTIONS) >= 0);
 
+        assert_se(json_build(&v, JSON_BUILD_OBJECT(JSON_BUILD_PAIR("a", JSON_BUILD_INTEGER(7)),
+                                                   JSON_BUILD_PAIR("b", JSON_BUILD_INTEGER(22)))) >= 0);
+
         assert_se(varlink_connect_address(&c, sp) >= 0);
         assert_se(varlink_set_description(c, "main-client") >= 0);
         assert_se(varlink_bind_reply(c, reply) >= 0);
 
-        assert_se(json_build(&v, JSON_BUILD_OBJECT(JSON_BUILD_PAIR("a", JSON_BUILD_INTEGER(7)),
-                                                   JSON_BUILD_PAIR("b", JSON_BUILD_INTEGER(22)))) >= 0);
-
         assert_se(varlink_invoke(c, "io.test.DoSomething", v) >= 0);
 
         assert_se(varlink_attach_event(c, e, 0) >= 0);