]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
varlink: rework varlink_collect()
authorLennart Poettering <lennart@poettering.net>
Thu, 8 Feb 2024 10:34:49 +0000 (11:34 +0100)
committerLennart Poettering <lennart@poettering.net>
Mon, 12 Feb 2024 11:04:18 +0000 (12:04 +0100)
This reworks varlink_collect() so that it is not just a wrapper around
varlink_observe(), varlink_bind_reply() and others. It becomes a first
class operation.

This has various benefits:

1. Memory management is normalized: the reply json variant is now
   tracked as part of the varlink object, and thus we do not pass
   ownership to the caller. This is just like we do it for simple method
   calls and removes a lot of confusion.
2. The bind reply/user data pointer can be used for user stuff, we'll
   not silently override this.
3. We enforce an overall time-out operation on the whole thing, so that
   this synchronous operation does no longer block forever.

src/shared/varlink.c
src/shared/varlink.h
src/test/test-varlink.c

index 10d4da1a79d4b1e16a543393d5c34fa2bee9c4b0..8a5a9907a8aa3f025080255f5049bf5425337695 100644 (file)
@@ -46,6 +46,8 @@ typedef enum VarlinkState {
         VARLINK_AWAITING_REPLY_MORE,
         VARLINK_CALLING,
         VARLINK_CALLED,
+        VARLINK_COLLECTING,
+        VARLINK_COLLECTING_REPLY,
         VARLINK_PROCESSING_REPLY,
 
         /* Server side states */
@@ -80,6 +82,8 @@ typedef enum VarlinkState {
                VARLINK_AWAITING_REPLY_MORE,             \
                VARLINK_CALLING,                         \
                VARLINK_CALLED,                          \
+               VARLINK_COLLECTING,                      \
+               VARLINK_COLLECTING_REPLY,                \
                VARLINK_PROCESSING_REPLY,                \
                VARLINK_IDLE_SERVER,                     \
                VARLINK_PROCESSING_METHOD,               \
@@ -170,6 +174,7 @@ struct Varlink {
         VarlinkReply reply_callback;
 
         JsonVariant *current;
+        JsonVariant *current_collected;
         VarlinkReplyFlags current_reply_flags;
         VarlinkSymbol *current_method;
 
@@ -245,18 +250,14 @@ struct VarlinkServer {
         bool exit_on_idle;
 };
 
-typedef struct VarlinkCollectContext {
-        JsonVariant *parameters;
-        const char *error_id;
-        VarlinkReplyFlags flags;
-} VarlinkCollectContext ;
-
 static const char* const varlink_state_table[_VARLINK_STATE_MAX] = {
         [VARLINK_IDLE_CLIENT]              = "idle-client",
         [VARLINK_AWAITING_REPLY]           = "awaiting-reply",
         [VARLINK_AWAITING_REPLY_MORE]      = "awaiting-reply-more",
         [VARLINK_CALLING]                  = "calling",
         [VARLINK_CALLED]                   = "called",
+        [VARLINK_COLLECTING]               = "collecting",
+        [VARLINK_COLLECTING_REPLY]         = "collecting-reply",
         [VARLINK_PROCESSING_REPLY]         = "processing-reply",
         [VARLINK_IDLE_SERVER]              = "idle-server",
         [VARLINK_PROCESSING_METHOD]        = "processing-method",
@@ -690,6 +691,7 @@ static void varlink_clear_current(Varlink *v) {
 
         /* Clears the currently processed incoming message */
         v->current = json_variant_unref(v->current);
+        v->current_collected = json_variant_unref(v->current_collected);
         v->current_method = NULL;
         v->current_reply_flags = 0;
 
@@ -774,7 +776,7 @@ static int varlink_test_disconnect(Varlink *v) {
                 goto disconnect;
 
         /* If we are waiting for incoming data but the read side is shut down, disconnect. */
-        if (IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING, VARLINK_IDLE_SERVER) && v->read_disconnected)
+        if (IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING, VARLINK_COLLECTING, VARLINK_IDLE_SERVER) && v->read_disconnected)
                 goto disconnect;
 
         /* Similar, if are a client that hasn't written anything yet but the write side is dead, also
@@ -897,7 +899,7 @@ static int varlink_read(Varlink *v) {
 
         assert(v);
 
-        if (!IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING, VARLINK_IDLE_SERVER))
+        if (!IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING, VARLINK_COLLECTING, VARLINK_IDLE_SERVER))
                 return 0;
         if (v->connecting) /* read() on a socket while we are in connect() will fail with EINVAL, hence exit early here */
                 return 0;
@@ -1093,7 +1095,7 @@ static int varlink_parse_message(Varlink *v) {
 static int varlink_test_timeout(Varlink *v) {
         assert(v);
 
-        if (!IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING))
+        if (!IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING, VARLINK_COLLECTING))
                 return 0;
         if (v->timeout == USEC_INFINITY)
                 return 0;
@@ -1183,7 +1185,7 @@ static int varlink_dispatch_reply(Varlink *v) {
 
         assert(v);
 
-        if (!IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING))
+        if (!IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING, VARLINK_COLLECTING))
                 return 0;
         if (!v->current)
                 return 0;
@@ -1226,7 +1228,7 @@ static int varlink_dispatch_reply(Varlink *v) {
         }
 
         /* Replies with 'continue' set are only OK if we set 'more' when the method call was initiated */
-        if (v->state != VARLINK_AWAITING_REPLY_MORE && FLAGS_SET(flags, VARLINK_REPLY_CONTINUES))
+        if (!IN_SET(v->state, VARLINK_AWAITING_REPLY_MORE, VARLINK_COLLECTING) && FLAGS_SET(flags, VARLINK_REPLY_CONTINUES))
                 goto invalid;
 
         /* An error is final */
@@ -1260,7 +1262,9 @@ static int varlink_dispatch_reply(Varlink *v) {
                                           FLAGS_SET(flags, VARLINK_REPLY_CONTINUES) ? VARLINK_AWAITING_REPLY_MORE :
                                           v->n_pending == 0 ? VARLINK_IDLE_CLIENT : VARLINK_AWAITING_REPLY);
                 }
-        } else {
+        } else if (v->state == VARLINK_COLLECTING)
+                varlink_set_state(v, VARLINK_COLLECTING_REPLY);
+        else {
                 assert(v->state == VARLINK_CALLING);
                 varlink_set_state(v, VARLINK_CALLED);
         }
@@ -1738,7 +1742,7 @@ int varlink_get_events(Varlink *v) {
                 return EPOLLOUT;
 
         if (!v->read_disconnected &&
-            IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING, VARLINK_IDLE_SERVER) &&
+            IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING, VARLINK_COLLECTING, VARLINK_IDLE_SERVER) &&
             !v->current &&
             v->input_buffer_unscanned <= 0)
                 ret |= EPOLLIN;
@@ -1756,7 +1760,7 @@ int varlink_get_timeout(Varlink *v, usec_t *ret) {
         if (v->state == VARLINK_DISCONNECTED)
                 return varlink_log_errno(v, SYNTHETIC_ERRNO(ENOTCONN), "Not connected.");
 
-        if (IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING) &&
+        if (IN_SET(v->state, VARLINK_AWAITING_REPLY, VARLINK_AWAITING_REPLY_MORE, VARLINK_CALLING, VARLINK_COLLECTING) &&
             v->timeout != USEC_INFINITY) {
                 if (ret)
                         *ret = usec_add(v->timestamp, v->timeout);
@@ -2321,47 +2325,7 @@ int varlink_callb_and_log(
         return varlink_call_and_log(v, method, parameters, ret_parameters);
 }
 
-static void varlink_collect_context_free(VarlinkCollectContext *cc) {
-        assert(cc);
-
-        json_variant_unref(cc->parameters);
-        free((char *)cc->error_id);
-}
-
-static int collect_callback(
-                Varlink *v,
-                JsonVariant *parameters,
-                const char *error_id,
-                VarlinkReplyFlags flags,
-                void *userdata) {
-
-        VarlinkCollectContext *context = ASSERT_PTR(userdata);
-        int r;
-
-        assert(v);
-
-        context->flags = flags;
-        /* If we hit an error, we will drop all collected replies and just return the error_id and flags in varlink_collect() */
-        if (error_id) {
-                context->error_id = error_id;
-
-                json_variant_unref(context->parameters);
-                context->parameters = json_variant_ref(parameters);
-
-                return 0;
-        }
-
-        if (json_variant_elements(context->parameters) >= VARLINK_COLLECT_MAX)
-                return varlink_log_errno(v, SYNTHETIC_ERRNO(E2BIG), "Number of reply messages grew too large (%zu) while collecting.", json_variant_elements(context->parameters));
-
-        r = json_variant_append_array(&context->parameters, parameters);
-        if (r < 0)
-                return varlink_log_errno(v, r, "Failed to append JSON object to array: %m");
-
-        return 1;
-}
-
-int varlink_collect(
+int varlink_collect_full(
                 Varlink *v,
                 const char *method,
                 JsonVariant *parameters,
@@ -2369,7 +2333,7 @@ int varlink_collect(
                 const char **ret_error_id,
                 VarlinkReplyFlags *ret_flags) {
 
-        _cleanup_(varlink_collect_context_free) VarlinkCollectContext context = {};
+        _cleanup_(json_variant_unrefp) JsonVariant *m = NULL, *collected = NULL;
         int r;
 
         assert_return(v, -EINVAL);
@@ -2386,71 +2350,102 @@ int varlink_collect(
          * that we can assign a new reply shortly. */
         varlink_clear_current(v);
 
-        r = varlink_bind_reply(v, collect_callback);
+        r = varlink_sanitize_parameters(&parameters);
         if (r < 0)
-                return varlink_log_errno(v, r, "Failed to bind collect callback");
+                return varlink_log_errno(v, r, "Failed to sanitize parameters: %m");
 
-        varlink_set_userdata(v, &context);
-        r = varlink_observe(v, method, parameters);
+        r = json_build(&m, JSON_BUILD_OBJECT(
+                                       JSON_BUILD_PAIR("method", JSON_BUILD_STRING(method)),
+                                       JSON_BUILD_PAIR("parameters", JSON_BUILD_VARIANT(parameters)),
+                                       JSON_BUILD_PAIR("more", JSON_BUILD_BOOLEAN(true))));
         if (r < 0)
-                return varlink_log_errno(v, r, "Failed to collect varlink method: %m");
-
-        while (v->state == VARLINK_AWAITING_REPLY_MORE) {
+                return varlink_log_errno(v, r, "Failed to build json message: %m");
 
-                r = varlink_process(v);
-                if (r < 0)
-                        return r;
+        r = varlink_enqueue_json(v, m);
+        if (r < 0)
+                return varlink_log_errno(v, r, "Failed to enqueue json message: %m");
 
-                /* If we get an error from any of the replies, return immediately with just the error_id and flags*/
-                if (context.error_id) {
+        varlink_set_state(v, VARLINK_COLLECTING);
+        v->n_pending++;
+        v->timestamp = now(CLOCK_MONOTONIC);
 
-                        /* If caller doesn't ask for the error string, then let's return an error code in case of failure */
-                        if (!ret_error_id)
-                                return varlink_error_to_errno(context.error_id, context.parameters);
+        for (;;) {
+                while (v->state == VARLINK_COLLECTING) {
+                        r = varlink_process(v);
+                        if (r < 0)
+                                return r;
+                        if (r > 0)
+                                continue;
 
-                        if (ret_parameters)
-                                *ret_parameters = TAKE_PTR(context.parameters);
-                        if (ret_error_id)
-                                *ret_error_id = TAKE_PTR(context.error_id);
-                        if (ret_flags)
-                                *ret_flags = context.flags;
-                        return 0;
+                        r = varlink_wait(v, USEC_INFINITY);
+                        if (r < 0)
+                                return r;
                 }
 
-                if (r > 0)
-                        continue;
+                switch (v->state) {
 
-                r = varlink_wait(v, USEC_INFINITY);
-                if (r < 0)
-                        return r;
-        }
+                case VARLINK_COLLECTING_REPLY: {
+                        assert(v->current);
 
-        switch (v->state) {
+                        JsonVariant *e = json_variant_by_key(v->current, "error"),
+                                *p = json_variant_by_key(v->current, "parameters");
 
-        case VARLINK_IDLE_CLIENT:
-                break;
+                        if (e) {
+                                if (!ret_error_id)
+                                        return varlink_error_to_errno(json_variant_string(e), p);
 
-        case VARLINK_PENDING_DISCONNECT:
-        case VARLINK_DISCONNECTED:
-                return varlink_log_errno(v, SYNTHETIC_ERRNO(ECONNRESET), "Connection was closed.");
+                                if (ret_parameters)
+                                        *ret_parameters = p;
+                                if (ret_error_id)
+                                        *ret_error_id = e ? json_variant_string(e) : NULL;
+                                if (ret_flags)
+                                        *ret_flags = v->current_reply_flags;
 
-        case VARLINK_PENDING_TIMEOUT:
-                return varlink_log_errno(v, SYNTHETIC_ERRNO(ETIME), "Connection timed out.");
+                                return 1;
+                        }
 
-        default:
-                assert_not_reached();
-        }
+                        if (json_variant_elements(collected) >= VARLINK_COLLECT_MAX)
+                                return varlink_log_errno(v, SYNTHETIC_ERRNO(E2BIG), "Number of reply messages grew too large (%zu) while collecting.", json_variant_elements(collected));
 
-        if (!ret_error_id && context.error_id)
-                return varlink_error_to_errno(context.error_id, context.parameters);
+                        r = json_variant_append_array(&collected, p);
+                        if (r < 0)
+                                return varlink_log_errno(v, r, "Failed to append JSON object to array: %m");
 
-        if (ret_parameters)
-                *ret_parameters = TAKE_PTR(context.parameters);
-        if (ret_error_id)
-                *ret_error_id = TAKE_PTR(context.error_id);
-        if (ret_flags)
-                *ret_flags = context.flags;
-        return 1;
+                        if (FLAGS_SET(v->current_reply_flags, VARLINK_REPLY_CONTINUES)) {
+                                /* There's more to collect, continue */
+                                varlink_clear_current(v);
+                                varlink_set_state(v, VARLINK_COLLECTING);
+                                continue;
+                        }
+
+                        varlink_set_state(v, VARLINK_IDLE_CLIENT);
+                        assert(v->n_pending == 1);
+                        v->n_pending--;
+
+                        if (ret_parameters)
+                                /* Install the collection array in the connection object, so that we can hand
+                                 * out a pointer to it without passing over ownership, to make it work more
+                                 * alike regular method call replies */
+                                *ret_parameters = v->current_collected = TAKE_PTR(collected);
+                        if (ret_error_id)
+                                *ret_error_id = NULL;
+                        if (ret_flags)
+                                *ret_flags = v->current_reply_flags;
+
+                        return 1;
+                }
+
+                case VARLINK_PENDING_DISCONNECT:
+                case VARLINK_DISCONNECTED:
+                        return varlink_log_errno(v, SYNTHETIC_ERRNO(ECONNRESET), "Connection was closed.");
+
+                case VARLINK_PENDING_TIMEOUT:
+                        return varlink_log_errno(v, SYNTHETIC_ERRNO(ETIME), "Connection timed out.");
+
+                default:
+                        assert_not_reached();
+                }
+        }
 }
 
 int varlink_collectb(
@@ -2458,7 +2453,7 @@ int varlink_collectb(
                 const char *method,
                 JsonVariant **ret_parameters,
                 const char **ret_error_id,
-                VarlinkReplyFlags *ret_flags, ...) {
+                ...) {
 
         _cleanup_(json_variant_unrefp) JsonVariant *parameters = NULL;
         va_list ap;
@@ -2466,14 +2461,14 @@ int varlink_collectb(
 
         assert_return(v, -EINVAL);
 
-        va_start(ap, ret_flags);
+        va_start(ap, ret_error_id);
         r = json_buildv(&parameters, ap);
         va_end(ap);
 
         if (r < 0)
                 return varlink_log_errno(v, r, "Failed to build json message: %m");
 
-        return varlink_collect(v, method, parameters, ret_parameters, ret_error_id, ret_flags);
+        return varlink_collect_full(v, method, parameters, ret_parameters, ret_error_id, NULL);
 }
 
 int varlink_reply(Varlink *v, JsonVariant *parameters) {
index 622ab797c5a07b62492941a64ec31925f40405c8..db7227b215472a03c439855af8fb26da6f4724c5 100644 (file)
@@ -116,8 +116,11 @@ static inline int varlink_callb(Varlink *v, const char *method, JsonVariant **re
 int varlink_callb_and_log(Varlink *v, const char *method, JsonVariant **ret_parameters, ...);
 
 /* Send method call and begin collecting all 'more' replies into an array, finishing when a final reply is sent */
-int varlink_collect(Varlink *v, const char *method, JsonVariant *parameters, JsonVariant **ret_parameters, const char **ret_error_id, VarlinkReplyFlags *ret_flags);
-int varlink_collectb(Varlink *v, const char *method, JsonVariant **ret_parameters, const char **ret_error_id, VarlinkReplyFlags *ret_flags, ...);
+int varlink_collect_full(Varlink *v, const char *method, JsonVariant *parameters, JsonVariant **ret_parameters, const char **ret_error_id, VarlinkReplyFlags *ret_flags);
+static inline int varlink_collect(Varlink *v, const char *method, JsonVariant *parameters, JsonVariant **ret_parameters, const char **ret_error_id) {
+        return varlink_collect_full(v, method, parameters, ret_parameters, ret_error_id, NULL);
+}
+int varlink_collectb(Varlink *v, const char *method, JsonVariant **ret_parameters, const char **ret_error_id, ...);
 
 /* Enqueue method call, expect a reply, which is eventually delivered to the reply callback */
 int varlink_invoke(Varlink *v, const char *method, JsonVariant *parameters);
index b0b244e9178a123ebcb34430f80f82f60c010996..67ad21300271e4480f74c904bc5c5ad30bdbff01 100644 (file)
@@ -238,10 +238,9 @@ static void flood_test(const char *address) {
 
 static void *thread(void *arg) {
         _cleanup_(varlink_flush_close_unrefp) Varlink *c = NULL;
-        _cleanup_(json_variant_unrefp) JsonVariant *i = NULL, *j = NULL;
-        JsonVariant *o = NULL, *k = NULL;
+        _cleanup_(json_variant_unrefp) JsonVariant *i = NULL;
+        JsonVariant *o = NULL, *k = NULL, *j = NULL;
         const char *error_id;
-        VarlinkReplyFlags flags = 0;
         const char *e;
         int x = 0;
 
@@ -253,10 +252,9 @@ static void *thread(void *arg) {
         assert_se(varlink_set_allow_fd_passing_input(c, true) >= 0);
         assert_se(varlink_set_allow_fd_passing_output(c, true) >= 0);
 
-        assert_se(varlink_collect(c, "io.test.DoSomethingMore", i, &j, &error_id, &flags) >= 0);
+        assert_se(varlink_collect(c, "io.test.DoSomethingMore", i, &j, &error_id) >= 0);
 
         assert_se(!error_id);
-        assert_se(!flags);
         assert_se(json_variant_is_array(j) && !json_variant_is_blank_array(j));
 
         JSON_VARIANT_ARRAY_FOREACH(k, j) {