]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
varlink: implement the org.varlink.service introspection interface by default + hook...
authorLennart Poettering <lennart@poettering.net>
Fri, 22 Sep 2023 20:39:25 +0000 (22:39 +0200)
committerLennart Poettering <lennart@poettering.net>
Fri, 6 Oct 2023 09:49:38 +0000 (11:49 +0200)
Fixes: #23874
src/shared/varlink.c

index 2295f460165cdfdeabd62ea0a26e70d553353812..aac25bbea8b9f856facf0dfec5918781024795ce 100644 (file)
@@ -23,6 +23,9 @@
 #include "user-util.h"
 #include "varlink.h"
 #include "varlink-internal.h"
+#include "varlink-org.varlink.service.h"
+#include "varlink-io.systemd.h"
+#include "version.h"
 
 #define VARLINK_DEFAULT_CONNECTIONS_MAX 4096U
 #define VARLINK_DEFAULT_CONNECTIONS_PER_UID_MAX 1024U
@@ -162,6 +165,7 @@ struct Varlink {
         VarlinkReply reply_callback;
 
         JsonVariant *current;
+        VarlinkSymbol *current_method;
 
         struct ucred ucred;
         bool ucred_acquired:1;
@@ -210,8 +214,9 @@ struct VarlinkServer {
 
         LIST_HEAD(VarlinkServerSocket, sockets);
 
-        Hashmap *methods;
-        Hashmap *interfaces;
+        Hashmap *methods;              /* Fully qualified symbol name of a method → VarlinkMethod */
+        Hashmap *interfaces;           /* Fully qualified interface name → VarlinkInterface* */
+        Hashmap *symbols;              /* Fully qualified symbol name of methord/error → VarlinkSymbol* */
         VarlinkConnect connect_callback;
         VarlinkDisconnect disconnect_callback;
 
@@ -219,7 +224,7 @@ struct VarlinkServer {
         int64_t event_priority;
 
         unsigned n_connections;
-        Hashmap *by_uid;
+        Hashmap *by_uid;               /* UID_TO_PTR(uid) → UINT_TO_PTR(n_connections) */
 
         void *userdata;
         char *description;
@@ -443,6 +448,7 @@ static void varlink_clear_current(Varlink *v) {
 
         /* Clears the currently processed incoming message */
         v->current = json_variant_unref(v->current);
+        v->current_method = NULL;
 
         close_many(v->input_fds, v->n_input_fds);
         v->input_fds = mfree(v->input_fds);
@@ -988,10 +994,84 @@ invalid:
         return 1;
 }
 
+static int generic_method_get_info(
+                Varlink *link,
+                JsonVariant *parameters,
+                VarlinkMethodFlags flags,
+                void *userdata) {
+
+        _cleanup_strv_free_ char **interfaces = NULL;
+        _cleanup_free_ char *product = NULL;
+        int r;
+
+        assert(link);
+
+        if (json_variant_elements(parameters) != 0)
+                return varlink_errorb(link, VARLINK_ERROR_INVALID_PARAMETER,
+                                      JSON_BUILD_OBJECT(
+                                                      JSON_BUILD_PAIR_VARIANT("parameter", json_variant_by_index(parameters, 0))));
+
+        product = strjoin("systemd (", program_invocation_short_name, ")");
+        if (!product)
+                return -ENOMEM;
+
+        VarlinkInterface *interface;
+        HASHMAP_FOREACH(interface, ASSERT_PTR(link->server)->interfaces) {
+                r = strv_extend(&interfaces, interface->name);
+                if (r < 0)
+                        return r;
+        }
+
+        strv_sort(interfaces);
+
+        return varlink_replyb(link, JSON_BUILD_OBJECT(
+                                              JSON_BUILD_PAIR_STRING("vendor", "The systemd Project"),
+                                              JSON_BUILD_PAIR_STRING("product", product),
+                                              JSON_BUILD_PAIR_STRING("version", STRINGIFY(PROJECT_VERSION) " (" GIT_VERSION ")"),
+                                              JSON_BUILD_PAIR_STRING("url", "https://systemd.io/"),
+                                              JSON_BUILD_PAIR_STRV("interfaces", interfaces)));
+}
+
+static int generic_method_get_interface_description(
+                Varlink *link,
+                JsonVariant *parameters,
+                VarlinkMethodFlags flags,
+                void *userdata) {
+
+        static const struct JsonDispatch dispatch_table[] = {
+                { "interface",  JSON_VARIANT_STRING, json_dispatch_const_string, 0, JSON_MANDATORY },
+                {}
+        };
+        _cleanup_free_ char *text = NULL;
+        const VarlinkInterface *interface;
+        const char *name = NULL;
+        int r;
+
+        assert(link);
+
+        r = json_dispatch(parameters, dispatch_table, NULL, 0, &name);
+        if (r < 0)
+                return r;
+
+        interface = hashmap_get(ASSERT_PTR(link->server)->interfaces, name);
+        if (!interface)
+                return varlink_errorb(link, VARLINK_ERROR_INTERFACE_NOT_FOUND,
+                                      JSON_BUILD_OBJECT(
+                                                      JSON_BUILD_PAIR_STRING("interface", name)));
+
+        r = varlink_idl_format(interface, &text);
+        if (r < 0)
+                return r;
+
+        return varlink_replyb(link,
+                           JSON_BUILD_OBJECT(
+                                           JSON_BUILD_PAIR_STRING("description", text)));
+}
+
 static int varlink_dispatch_method(Varlink *v) {
         _cleanup_(json_variant_unrefp) JsonVariant *parameters = NULL;
         VarlinkMethodFlags flags = 0;
-        const char *method = NULL, *error;
+        const char *method = NULL;
         JsonVariant *e;
         VarlinkMethod callback;
         const char *k;
@@ -1064,37 +1144,51 @@ static int varlink_dispatch_method(Varlink *v) {
 
         assert(v->server);
 
-        if (STR_IN_SET(method, "org.varlink.service.GetInfo", "org.varlink.service.GetInterface")) {
-                /* For now, we don't implement a single of varlink's own methods */
-                callback = NULL;
-                error = VARLINK_ERROR_METHOD_NOT_IMPLEMENTED;
-        } else if (startswith(method, "org.varlink.service.")) {
-                callback = NULL;
-                error = VARLINK_ERROR_METHOD_NOT_FOUND;
-        } else {
-                callback = hashmap_get(v->server->methods, method);
-                error = VARLINK_ERROR_METHOD_NOT_FOUND;
+        /* First consult user supplied method implementations */
+        callback = hashmap_get(v->server->methods, method);
+        if (!callback) {
+                if (streq(method, "org.varlink.service.GetInfo"))
+                        callback = generic_method_get_info;
+                else if (streq(method, "org.varlink.service.GetInterfaceDescription"))
+                        callback = generic_method_get_interface_description;
         }
 
         if (callback) {
-                r = callback(v, parameters, flags, v->userdata);
-                if (r < 0) {
-                        log_debug_errno(r, "Callback for %s returned error: %m", method);
+                bool invalid = false;
+
+                v->current_method = hashmap_get(v->server->symbols, method);
+                if (!v->current_method)
+                        log_debug("No interface description defined for method '%s', not validating.", method);
+                else {
+                        const char *bad_field;
+
+                        r = varlink_idl_validate_method_call(v->current_method, parameters, &bad_field);
+                        if (r < 0) {
+                                log_debug_errno(r, "Parameters for method %s() didn't pass validation on field '%s': %m", method, strna(bad_field));
+                                r = varlink_errorb(v, VARLINK_ERROR_INVALID_PARAMETER, JSON_BUILD_OBJECT(JSON_BUILD_PAIR_STRING("parameter", bad_field)));
+                                invalid = true;
+                        }
+                }
 
-                        /* We got an error back from the callback. Propagate it to the client if the method call remains unanswered. */
-                        if (!FLAGS_SET(flags, VARLINK_METHOD_ONEWAY)) {
-                                r = varlink_error_errno(v, r);
-                                if (r < 0)
-                                        return r;
+                if (!invalid) {
+                        r = callback(v, parameters, flags, v->userdata);
+                        if (r < 0) {
+                                log_debug_errno(r, "Callback for %s returned error: %m", method);
+
+                                /* We got an error back from the callback. Propagate it to the client if the method call remains unanswered. */
+                                if (!FLAGS_SET(flags, VARLINK_METHOD_ONEWAY)) {
+                                        r = varlink_error_errno(v, r);
+                                        if (r < 0)
+                                                return r;
+                                }
                         }
                 }
         } else if (!FLAGS_SET(flags, VARLINK_METHOD_ONEWAY)) {
-                assert(error);
-
-                r = varlink_errorb(v, error, JSON_BUILD_OBJECT(JSON_BUILD_PAIR("method", JSON_BUILD_STRING(method))));
+                r = varlink_errorb(v, VARLINK_ERROR_METHOD_NOT_FOUND, JSON_BUILD_OBJECT(JSON_BUILD_PAIR("method", JSON_BUILD_STRING(method))));
                 if (r < 0)
                         return r;
-        }
+        } else
+                r = 0;
 
         switch (v->state) {
 
@@ -1114,7 +1208,6 @@ static int varlink_dispatch_method(Varlink *v) {
 
         default:
                 assert_not_reached();
-
         }
 
         return r;
@@ -1859,6 +1952,14 @@ int varlink_reply(Varlink *v, JsonVariant *parameters) {
         if (r < 0)
                 return varlink_log_errno(v, r, "Failed to build json message: %m");
 
+        if (v->current_method) {
+                const char *bad_field = NULL;
+
+                r = varlink_idl_validate_method_reply(v->current_method, parameters, &bad_field);
+                if (r < 0)
+                        log_debug_errno(r, "Return parameters for method reply %s() didn't pass validation on field '%s', ignoring: %m", v->current_method->name, strna(bad_field));
+        }
+
         r = varlink_enqueue_json(v, m);
         if (r < 0)
                 return varlink_log_errno(v, r, "Failed to enqueue json message: %m");
@@ -1925,6 +2026,17 @@ int varlink_error(Varlink *v, const char *error_id, JsonVariant *parameters) {
         if (r < 0)
                 return varlink_log_errno(v, r, "Failed to build json message: %m");
 
+        VarlinkSymbol *symbol = hashmap_get(v->server->symbols, error_id);
+        if (!symbol)
+                log_debug("No interface description defined for error '%s', not validating.", error_id);
+        else {
+                const char *bad_field = NULL;
+
+                r = varlink_idl_validate_method_reply(symbol, parameters, &bad_field);
+                if (r < 0)
+                        log_debug_errno(r, "Parameters for error %s didn't pass validation on field '%s', ignoring: %m", error_id, strna(bad_field));
+        }
+
         r = varlink_enqueue_json(v, m);
         if (r < 0)
                 return varlink_log_errno(v, r, "Failed to enqueue json message: %m");
@@ -2026,6 +2138,14 @@ int varlink_notify(Varlink *v, JsonVariant *parameters) {
         if (r < 0)
                 return varlink_log_errno(v, r, "Failed to build json message: %m");
 
+        if (v->current_method) {
+                const char *bad_field = NULL;
+
+                r = varlink_idl_validate_method_reply(v->current_method, parameters, &bad_field);
+                if (r < 0)
+                        log_debug_errno(r, "Return parameters for method reply %s() didn't pass validation on field '%s', ignoring: %m", v->current_method->name, strna(bad_field));
+        }
+
         r = varlink_enqueue_json(v, m);
         if (r < 0)
                 return varlink_log_errno(v, r, "Failed to enqueue json message: %m");
@@ -2455,7 +2575,8 @@ int varlink_set_allow_fd_passing_output(Varlink *v, bool b) {
 }
 
 int varlink_server_new(VarlinkServer **ret, VarlinkServerFlags flags) {
-        VarlinkServer *s;
+        _cleanup_(varlink_server_unrefp) VarlinkServer *s = NULL;
+        int r;
 
         assert_return(ret, -EINVAL);
         assert_return((flags & ~_VARLINK_SERVER_FLAGS_ALL) == 0, -EINVAL);
@@ -2471,7 +2592,14 @@ int varlink_server_new(VarlinkServer **ret, VarlinkServerFlags flags) {
                 .connections_per_uid_max = varlink_server_connections_per_uid_max(NULL),
         };
 
-        *ret = s;
+        r = varlink_server_add_interface_many(
+                        s,
+                        &vl_interface_io_systemd,
+                        &vl_interface_org_varlink_service);
+        if (r < 0)
+                return r;
+
+        *ret = TAKE_PTR(s);
         return 0;
 }
 
@@ -2488,6 +2616,7 @@ static VarlinkServer* varlink_server_destroy(VarlinkServer *s) {
 
         hashmap_free(s->methods);
         hashmap_free(s->interfaces);
+        hashmap_free(s->symbols);
         hashmap_free(s->by_uid);
 
         sd_event_unref(s->event);
@@ -2891,6 +3020,22 @@ sd_event *varlink_server_get_event(VarlinkServer *s) {
         return s->event;
 }
 
+static bool varlink_symbol_in_interface(const char *method, const char *interface) {
+        const char *p;
+
+        assert(method);
+        assert(interface);
+
+        p = startswith(method, interface);
+        if (!p)
+                return false;
+
+        if (*p != '.')
+                return false;
+
+        return !strchr(p+1, '.');
+}
+
 int varlink_server_bind_method(VarlinkServer *s, const char *method, VarlinkMethod callback) {
         _cleanup_free_ char *m = NULL;
         int r;
@@ -2899,7 +3044,8 @@ int varlink_server_bind_method(VarlinkServer *s, const char *method, VarlinkMeth
         assert_return(method, -EINVAL);
         assert_return(callback, -EINVAL);
 
-        if (startswith(method, "org.varlink.service."))
+        if (varlink_symbol_in_interface(method, "org.varlink.service") ||
+            varlink_symbol_in_interface(method, "io.systemd"))
                 return log_debug_errno(SYNTHETIC_ERRNO(EEXIST), "Cannot bind server to '%s'.", method);
 
         m = strdup(method);
@@ -2964,6 +3110,8 @@ int varlink_server_bind_disconnect(VarlinkServer *s, VarlinkDisconnect callback)
 }
 
 int varlink_server_add_interface(VarlinkServer *s, const VarlinkInterface *interface) {
+        int r;
+
         assert_return(s, -EINVAL);
         assert_return(interface, -EINVAL);
         assert_return(interface->name, -EINVAL);
@@ -2971,7 +3119,30 @@ int varlink_server_add_interface(VarlinkServer *s, const VarlinkInterface *inter
         if (hashmap_contains(s->interfaces, interface->name))
                 return log_debug_errno(SYNTHETIC_ERRNO(EEXIST), "Duplicate registration of interface '%s'.", interface->name);
 
-        return hashmap_ensure_put(&s->interfaces, &string_hash_ops, interface->name, (void*) interface);
+        r = hashmap_ensure_put(&s->interfaces, &string_hash_ops, interface->name, (void*) interface);
+        if (r < 0)
+                return r;
+
+        for (const VarlinkSymbol *const*symbol = interface->symbols; *symbol; symbol++) {
+                _cleanup_free_ char *j = NULL;
+
+                /* We only ever want to validate method calls/replies and errors against the interface
+                 * definitions, hence don't bother with the type symbols */
+                if (!IN_SET((*symbol)->symbol_type, VARLINK_METHOD, VARLINK_ERROR))
+                        continue;
+
+                j = strjoin(interface->name, ".", (*symbol)->name);
+                if (!j)
+                        return -ENOMEM;
+
+                r = hashmap_ensure_put(&s->symbols, &string_hash_ops_free, j, (void*) *symbol);
+                if (r < 0)
+                        return r;
+
+                TAKE_PTR(j);
+        }
+
+        return 0;
 }
 
 int varlink_server_add_interface_many_internal(VarlinkServer *s, ...) {