]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/firstboot/firstboot.c
tree-wide: use sd_id128_is_null() instead of sd_id128_equal where appropriate
[thirdparty/systemd.git] / src / firstboot / firstboot.c
index 215c059ee201cd1759b913ebbc8612cd2db772c8..c9e8e54ee38f00b0d8f234eda830691526a2afd9 100644 (file)
@@ -1,5 +1,3 @@
-/*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
-
 /***
   This file is part of systemd.
 
   along with systemd; If not, see <http://www.gnu.org/licenses/>.
 ***/
 
-
 #include <fcntl.h>
-#include <unistd.h>
 #include <getopt.h>
 #include <shadow.h>
+#include <unistd.h>
 
-#include "strv.h"
-#include "fileio.h"
+#include "alloc-util.h"
+#include "ask-password-api.h"
 #include "copy.h"
-#include "build.h"
+#include "fd-util.h"
+#include "fileio.h"
+#include "fs-util.h"
+#include "hostname-util.h"
+#include "locale-util.h"
 #include "mkdir.h"
-#include "time-util.h"
+#include "parse-util.h"
 #include "path-util.h"
-#include "locale-util.h"
-#include "ask-password-api.h"
+#include "random-util.h"
+#include "string-util.h"
+#include "strv.h"
+#include "terminal-util.h"
+#include "time-util.h"
+#include "umask-util.h"
+#include "user-util.h"
 
 static char *arg_root = NULL;
 static char *arg_locale = NULL;  /* $LANG */
@@ -50,17 +56,6 @@ static bool arg_copy_locale = false;
 static bool arg_copy_timezone = false;
 static bool arg_copy_root_password = false;
 
-#define prefix_roota(p) (arg_root ? (const char*) strappenda(arg_root, p) : (const char*) p)
-
-static void clear_string(char *x) {
-
-        if (!x)
-                return;
-
-        /* A delicious drop of snake-oil! */
-        memset(x, 'x', strlen(x));
-}
-
 static bool press_any_key(void) {
         char k = 0;
         bool need_nl = true;
@@ -68,7 +63,7 @@ static bool press_any_key(void) {
         printf("-- Press any key to proceed --");
         fflush(stdout);
 
-        read_one_char(stdin, &k, USEC_INFINITY, &need_nl);
+        (void) read_one_char(stdin, &k, USEC_INFINITY, &need_nl);
 
         if (need_nl)
                 putchar('\n');
@@ -85,20 +80,20 @@ static void print_welcome(void) {
         if (done)
                 return;
 
-        os_release = prefix_roota("/etc/os-release");
+        os_release = prefix_roota(arg_root, "/etc/os-release");
         r = parse_env_file(os_release, NEWLINE,
                            "PRETTY_NAME", &pretty_name,
                            NULL);
         if (r == -ENOENT) {
 
-                os_release = prefix_roota("/usr/lib/os-release");
+                os_release = prefix_roota(arg_root, "/usr/lib/os-release");
                 r = parse_env_file(os_release, NEWLINE,
                                    "PRETTY_NAME", &pretty_name,
                                    NULL);
         }
 
         if (r < 0 && r != -ENOENT)
-                log_warning("Failed to read os-release file: %s", strerror(-r));
+                log_warning_errno(r, "Failed to read os-release file: %m");
 
         printf("\nWelcome to your new installation of %s!\nPlease configure a few basic system settings:\n\n",
                isempty(pretty_name) ? "Linux" : pretty_name);
@@ -165,11 +160,9 @@ static int prompt_loop(const char *text, char **l, bool (*is_valid)(const char *
                 _cleanup_free_ char *p = NULL;
                 unsigned u;
 
-                r = ask_string(&p, "%s %s (empty to skip): ", draw_special_char(DRAW_TRIANGULAR_BULLET), text);
-                if (r < 0) {
-                        log_error("Failed to query user: %s", strerror(-r));
-                        return r;
-                }
+                r = ask_string(&p, "%s %s (empty to skip): ", special_glyph(TRIANGULAR_BULLET), text);
+                if (r < 0)
+                        return log_error_errno(r, "Failed to query user: %m");
 
                 if (isempty(p)) {
                         log_warning("No data entered, skipping.");
@@ -219,10 +212,8 @@ static int prompt_locale(void) {
                 return 0;
 
         r = get_locales(&locales);
-        if (r < 0) {
-                log_error("Cannot query locales list: %s", strerror(-r));
-                return r;
-        }
+        if (r < 0)
+                return log_error_errno(r, "Cannot query locales list: %m");
 
         print_welcome();
 
@@ -253,19 +244,17 @@ static int process_locale(void) {
         unsigned i = 0;
         int r;
 
-        etc_localeconf = prefix_roota("/etc/locale.conf");
-        if (faccessat(AT_FDCWD, etc_localeconf, F_OK, AT_SYMLINK_NOFOLLOW) >= 0)
+        etc_localeconf = prefix_roota(arg_root, "/etc/locale.conf");
+        if (laccess(etc_localeconf, F_OK) >= 0)
                 return 0;
 
         if (arg_copy_locale && arg_root) {
 
                 mkdir_parents(etc_localeconf, 0755);
-                r = copy_file("/etc/locale.conf", etc_localeconf, 0, 0644);
+                r = copy_file("/etc/locale.conf", etc_localeconf, 0, 0644, 0);
                 if (r != -ENOENT) {
-                        if (r < 0) {
-                                log_error("Failed to copy %s: %s", etc_localeconf, strerror(-r));
-                                return r;
-                        }
+                        if (r < 0)
+                                return log_error_errno(r, "Failed to copy %s: %m", etc_localeconf);
 
                         log_info("%s copied.", etc_localeconf);
                         return 0;
@@ -277,9 +266,9 @@ static int process_locale(void) {
                 return r;
 
         if (!isempty(arg_locale))
-                locales[i++] = strappenda("LANG=", arg_locale);
+                locales[i++] = strjoina("LANG=", arg_locale);
         if (!isempty(arg_locale_messages) && !streq(arg_locale_messages, arg_locale))
-                locales[i++] = strappenda("LC_MESSAGES=", arg_locale_messages);
+                locales[i++] = strjoina("LC_MESSAGES=", arg_locale_messages);
 
         if (i == 0)
                 return 0;
@@ -288,10 +277,8 @@ static int process_locale(void) {
 
         mkdir_parents(etc_localeconf, 0755);
         r = write_env_file(etc_localeconf, locales);
-        if (r < 0) {
-                log_error("Failed to write %s: %s", etc_localeconf, strerror(-r));
-                return r;
-        }
+        if (r < 0)
+                return log_error_errno(r, "Failed to write %s: %m", etc_localeconf);
 
         log_info("%s written.", etc_localeconf);
         return 0;
@@ -308,10 +295,8 @@ static int prompt_timezone(void) {
                 return 0;
 
         r = get_timezones(&zones);
-        if (r < 0) {
-                log_error("Cannot query timezone list: %s", strerror(-r));
-                return r;
-        }
+        if (r < 0)
+                return log_error_errno(r, "Cannot query timezone list: %m");
 
         print_welcome();
 
@@ -333,8 +318,8 @@ static int process_timezone(void) {
         const char *etc_localtime, *e;
         int r;
 
-        etc_localtime = prefix_roota("/etc/localtime");
-        if (faccessat(AT_FDCWD, etc_localtime, F_OK, AT_SYMLINK_NOFOLLOW) >= 0)
+        etc_localtime = prefix_roota(arg_root, "/etc/localtime");
+        if (laccess(etc_localtime, F_OK) >= 0)
                 return 0;
 
         if (arg_copy_timezone && arg_root) {
@@ -342,16 +327,12 @@ static int process_timezone(void) {
 
                 r = readlink_malloc("/etc/localtime", &p);
                 if (r != -ENOENT) {
-                        if (r < 0) {
-                                log_error("Failed to read host timezone: %s", strerror(-r));
-                                return r;
-                        }
+                        if (r < 0)
+                                return log_error_errno(r, "Failed to read host timezone: %m");
 
                         mkdir_parents(etc_localtime, 0755);
-                        if (symlink(p, etc_localtime) < 0) {
-                                log_error("Failed to create %s symlink: %m", etc_localtime);
-                                return -errno;
-                        }
+                        if (symlink(p, etc_localtime) < 0)
+                                return log_error_errno(errno, "Failed to create %s symlink: %m", etc_localtime);
 
                         log_info("%s copied.", etc_localtime);
                         return 0;
@@ -365,13 +346,11 @@ static int process_timezone(void) {
         if (isempty(arg_timezone))
                 return 0;
 
-        e = strappenda("../usr/share/zoneinfo/", arg_timezone);
+        e = strjoina("../usr/share/zoneinfo/", arg_timezone);
 
         mkdir_parents(etc_localtime, 0755);
-        if (symlink(e, etc_localtime) < 0) {
-                log_error("Failed to create %s symlink: %m", etc_localtime);
-                return -errno;
-        }
+        if (symlink(e, etc_localtime) < 0)
+                return log_error_errno(errno, "Failed to create %s symlink: %m", etc_localtime);
 
         log_info("%s written", etc_localtime);
         return 0;
@@ -392,23 +371,22 @@ static int prompt_hostname(void) {
         for (;;) {
                 _cleanup_free_ char *h = NULL;
 
-                r = ask_string(&h, "%s Please enter hostname for new system (empty to skip): ", draw_special_char(DRAW_TRIANGULAR_BULLET));
-                if (r < 0) {
-                        log_error("Failed to query hostname: %s", strerror(-r));
-                        return r;
-                }
+                r = ask_string(&h, "%s Please enter hostname for new system (empty to skip): ", special_glyph(TRIANGULAR_BULLET));
+                if (r < 0)
+                        return log_error_errno(r, "Failed to query hostname: %m");
 
                 if (isempty(h)) {
                         log_warning("No hostname entered, skipping.");
                         break;
                 }
 
-                if (!hostname_is_valid(h)) {
+                if (!hostname_is_valid(h, true)) {
                         log_error("Specified hostname invalid.");
                         continue;
                 }
 
-                arg_hostname = h;
+                /* Get rid of the trailing dot that we allow, but don't want to see */
+                arg_hostname = hostname_cleanup(h);
                 h = NULL;
                 break;
         }
@@ -420,8 +398,8 @@ static int process_hostname(void) {
         const char *etc_hostname;
         int r;
 
-        etc_hostname = prefix_roota("/etc/hostname");
-        if (faccessat(AT_FDCWD, etc_hostname, F_OK, AT_SYMLINK_NOFOLLOW) >= 0)
+        etc_hostname = prefix_roota(arg_root, "/etc/hostname");
+        if (laccess(etc_hostname, F_OK) >= 0)
                 return 0;
 
         r = prompt_hostname();
@@ -432,11 +410,9 @@ static int process_hostname(void) {
                 return 0;
 
         mkdir_parents(etc_hostname, 0755);
-        r = write_string_file(etc_hostname, arg_hostname);
-        if (r < 0) {
-                log_error("Failed to write %s: %s", etc_hostname, strerror(-r));
-                return r;
-        }
+        r = write_string_file(etc_hostname, arg_hostname, WRITE_STRING_FILE_CREATE);
+        if (r < 0)
+                return log_error_errno(r, "Failed to write %s: %m", etc_hostname);
 
         log_info("%s written.", etc_hostname);
         return 0;
@@ -447,22 +423,17 @@ static int process_machine_id(void) {
         char id[SD_ID128_STRING_MAX];
         int r;
 
-        etc_machine_id = prefix_roota("/etc/machine-id");
-        if (faccessat(AT_FDCWD, etc_machine_id, F_OK, AT_SYMLINK_NOFOLLOW) >= 0)
-                return 0;
-
-        if (!arg_root)
+        etc_machine_id = prefix_roota(arg_root, "/etc/machine-id");
+        if (laccess(etc_machine_id, F_OK) >= 0)
                 return 0;
 
-        if (sd_id128_equal(arg_machine_id, SD_ID128_NULL))
+        if (sd_id128_is_null(arg_machine_id))
                 return 0;
 
         mkdir_parents(etc_machine_id, 0755);
-        r = write_string_file(etc_machine_id, sd_id128_to_string(arg_machine_id, id));
-        if (r < 0) {
-                log_error("Failed to write machine id: %s", strerror(-r));
-                return r;
-        }
+        r = write_string_file(etc_machine_id, sd_id128_to_string(arg_machine_id, id), WRITE_STRING_FILE_CREATE);
+        if (r < 0)
+                return log_error_errno(r, "Failed to write machine id: %m");
 
         log_info("%s written.", etc_machine_id);
         return 0;
@@ -478,45 +449,37 @@ static int prompt_root_password(void) {
         if (!arg_prompt_root_password)
                 return 0;
 
-        etc_shadow = prefix_roota("/etc/shadow");
-        if (faccessat(AT_FDCWD, etc_shadow, F_OK, AT_SYMLINK_NOFOLLOW) >= 0)
+        etc_shadow = prefix_roota(arg_root, "/etc/shadow");
+        if (laccess(etc_shadow, F_OK) >= 0)
                 return 0;
 
         print_welcome();
         putchar('\n');
 
-        msg1 = strappenda(draw_special_char(DRAW_TRIANGULAR_BULLET), " Please enter a new root password (empty to skip): ");
-        msg2 = strappenda(draw_special_char(DRAW_TRIANGULAR_BULLET), " Please enter new root password again: ");
+        msg1 = strjoina(special_glyph(TRIANGULAR_BULLET), " Please enter a new root password (empty to skip): ");
+        msg2 = strjoina(special_glyph(TRIANGULAR_BULLET), " Please enter new root password again: ");
 
         for (;;) {
-                _cleanup_free_ char *a = NULL, *b = NULL;
+                _cleanup_string_free_erase_ char *a = NULL, *b = NULL;
 
-                r = ask_password_tty(msg1, 0, NULL, &a);
-                if (r < 0) {
-                        log_error("Failed to query root password: %s", strerror(-r));
-                        return r;
-                }
+                r = ask_password_tty(msg1, NULL, 0, 0, NULL, &a);
+                if (r < 0)
+                        return log_error_errno(r, "Failed to query root password: %m");
 
                 if (isempty(a)) {
                         log_warning("No password entered, skipping.");
                         break;
                 }
 
-                r = ask_password_tty(msg2, 0, NULL, &b);
-                if (r < 0) {
-                        log_error("Failed to query root password: %s", strerror(-r));
-                        clear_string(a);
-                        return r;
-                }
+                r = ask_password_tty(msg2, NULL, 0, 0, NULL, &b);
+                if (r < 0)
+                        return log_error_errno(r, "Failed to query root password: %m");
 
                 if (!streq(a, b)) {
                         log_error("Entered passwords did not match, please try again.");
-                        clear_string(a);
-                        clear_string(b);
                         continue;
                 }
 
-                clear_string(b);
                 arg_root_password = a;
                 a = NULL;
                 break;
@@ -537,7 +500,7 @@ static int write_root_shadow(const char *path, const struct spwd *p) {
 
         errno = 0;
         if (putspent(p, f) != 0)
-                return errno ? -errno : -EIO;
+                return errno > 0 ? -errno : -EIO;
 
         return fflush_and_check(f);
 }
@@ -552,9 +515,9 @@ static int process_root_password(void) {
 
         struct spwd item = {
                 .sp_namp = (char*) "root",
-                .sp_min = 0,
-                .sp_max = 99999,
-                .sp_warn = 7,
+                .sp_min = -1,
+                .sp_max = -1,
+                .sp_warn = -1,
                 .sp_inact = -1,
                 .sp_expire = -1,
                 .sp_flag = (unsigned long) -1, /* this appears to be what everybody does ... */
@@ -569,15 +532,15 @@ static int process_root_password(void) {
         const char *etc_shadow;
         int r;
 
-        etc_shadow = prefix_roota("/etc/shadow");
-        if (faccessat(AT_FDCWD, etc_shadow, F_OK, AT_SYMLINK_NOFOLLOW) >= 0)
+        etc_shadow = prefix_roota(arg_root, "/etc/shadow");
+        if (laccess(etc_shadow, F_OK) >= 0)
                 return 0;
 
         mkdir_parents(etc_shadow, 0755);
 
-        lock = take_password_lock(arg_root);
+        lock = take_etc_passwd_lock(arg_root);
         if (lock < 0)
-                return lock;
+                return log_error_errno(lock, "Failed to take a lock: %m");
 
         if (arg_copy_root_password && arg_root) {
                 struct spwd *p;
@@ -589,15 +552,12 @@ static int process_root_password(void) {
                                 if (!errno)
                                         errno = EIO;
 
-                                log_error("Failed to find shadow entry for root: %m");
-                                return -errno;
+                                return log_error_errno(errno, "Failed to find shadow entry for root: %m");
                         }
 
                         r = write_root_shadow(etc_shadow, p);
-                        if (r < 0) {
-                                log_error("Failed to write %s: %s", etc_shadow, strerror(-r));
-                                return r;
-                        }
+                        if (r < 0)
+                                return log_error_errno(r, "Failed to write %s: %m", etc_shadow);
 
                         log_info("%s copied.", etc_shadow);
                         return 0;
@@ -612,10 +572,8 @@ static int process_root_password(void) {
                 return 0;
 
         r = dev_urandom(raw, 16);
-        if (r < 0) {
-                log_error("Failed to get salt: %s", strerror(-r));
-                return r;
-        }
+        if (r < 0)
+                return log_error_errno(r, "Failed to get salt: %m");
 
         /* We only bother with SHA512 hashed passwords, the rest is legacy, and we don't do legacy. */
         assert_cc(sizeof(table) == 64 + 1);
@@ -629,19 +587,16 @@ static int process_root_password(void) {
         item.sp_pwdp = crypt(arg_root_password, salt);
         if (!item.sp_pwdp) {
                 if (!errno)
-                        errno = -EINVAL;
+                        errno = EINVAL;
 
-                log_error("Failed to encrypt password: %m");
-                return -errno;
+                return log_error_errno(errno, "Failed to encrypt password: %m");
         }
 
         item.sp_lstchg = (long) (now(CLOCK_REALTIME) / USEC_PER_DAY);
 
         r = write_root_shadow(etc_shadow, &item);
-        if (r < 0) {
-                log_error("Failed to write %s: %s", etc_shadow, strerror(-r));
-                return r;
-        }
+        if (r < 0)
+                return log_error_errno(r, "Failed to write %s: %m", etc_shadow);
 
         log_info("%s written.", etc_shadow);
         return 0;
@@ -664,7 +619,7 @@ static void help(void) {
                "     --prompt-timezone         Prompt the user for timezone\n"
                "     --prompt-hostname         Prompt the user for hostname\n"
                "     --prompt-root-password    Prompt the user for root password\n"
-               "     --prompt                  Prompt for locale, timezone, hostname, root password\n"
+               "     --prompt                  Prompt for all of the above\n"
                "     --copy-locale             Copy locale from host\n"
                "     --copy-timezone           Copy timezone from host\n"
                "     --copy-root-password      Copy root password from host\n"
@@ -735,23 +690,12 @@ static int parse_argv(int argc, char *argv[]) {
                         return 0;
 
                 case ARG_VERSION:
-                        puts(PACKAGE_STRING);
-                        puts(SYSTEMD_FEATURES);
-                        return 0;
+                        return version();
 
                 case ARG_ROOT:
-                        free(arg_root);
-                        arg_root = path_make_absolute_cwd(optarg);
-                        if (!arg_root)
-                                return log_oom();
-
-                        path_kill_slashes(arg_root);
-
-                        if (path_equal(arg_root, "/")) {
-                                free(arg_root);
-                                arg_root = NULL;
-                        }
-
+                        r = parse_path_argument_and_warn(optarg, true, &arg_root);
+                        if (r < 0)
+                                return r;
                         break;
 
                 case ARG_LOCALE:
@@ -760,9 +704,8 @@ static int parse_argv(int argc, char *argv[]) {
                                 return -EINVAL;
                         }
 
-                        free(arg_locale);
-                        arg_locale = strdup(optarg);
-                        if (!arg_locale)
+                        r = free_and_strdup(&arg_locale, optarg);
+                        if (r < 0)
                                 return log_oom();
 
                         break;
@@ -773,9 +716,8 @@ static int parse_argv(int argc, char *argv[]) {
                                 return -EINVAL;
                         }
 
-                        free(arg_locale_messages);
-                        arg_locale_messages = strdup(optarg);
-                        if (!arg_locale_messages)
+                        r = free_and_strdup(&arg_locale_messages, optarg);
+                        if (r < 0)
                                 return log_oom();
 
                         break;
@@ -786,42 +728,36 @@ static int parse_argv(int argc, char *argv[]) {
                                 return -EINVAL;
                         }
 
-                        free(arg_timezone);
-                        arg_timezone = strdup(optarg);
-                        if (!arg_timezone)
+                        r = free_and_strdup(&arg_timezone, optarg);
+                        if (r < 0)
                                 return log_oom();
 
                         break;
 
                 case ARG_ROOT_PASSWORD:
-                        free(arg_root_password);
-                        arg_root_password = strdup(optarg);
-                        if (!arg_root_password)
+                        r = free_and_strdup(&arg_root_password, optarg);
+                        if (r < 0)
                                 return log_oom();
-
                         break;
 
                 case ARG_ROOT_PASSWORD_FILE:
-                        free(arg_root_password);
-                        arg_root_password  = NULL;
+                        arg_root_password = mfree(arg_root_password);
 
                         r = read_one_line_file(optarg, &arg_root_password);
-                        if (r < 0) {
-                                log_error("Failed to read %s: %s", optarg, strerror(-r));
-                                return r;
-                        }
+                        if (r < 0)
+                                return log_error_errno(r, "Failed to read %s: %m", optarg);
 
                         break;
 
                 case ARG_HOSTNAME:
-                        if (!hostname_is_valid(optarg)) {
+                        if (!hostname_is_valid(optarg, true)) {
                                 log_error("Host name %s is not valid.", optarg);
                                 return -EINVAL;
                         }
 
-                        free(arg_hostname);
-                        arg_hostname = strdup(optarg);
-                        if (!arg_hostname)
+                        hostname_cleanup(optarg);
+                        r = free_and_strdup(&arg_hostname, optarg);
+                        if (r < 0)
                                 return log_oom();
 
                         break;
@@ -873,10 +809,8 @@ static int parse_argv(int argc, char *argv[]) {
                 case ARG_SETUP_MACHINE_ID:
 
                         r = sd_id128_randomize(&arg_machine_id);
-                        if (r < 0) {
-                                log_error("Failed to generate randomized machine ID: %s", strerror(-r));
-                                return r;
-                        }
+                        if (r < 0)
+                                return log_error_errno(r, "Failed to generate randomized machine ID: %m");
 
                         break;
 
@@ -929,7 +863,7 @@ finish:
         free(arg_locale_messages);
         free(arg_timezone);
         free(arg_hostname);
-        clear_string(arg_root_password);
+        string_erase(arg_root_password);
         free(arg_root_password);
 
         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;