]> git.ipfire.org Git - thirdparty/wireguard-tools.git/commitdiff
wg: fix removing preshared keys
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 23 Nov 2017 00:17:25 +0000 (01:17 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 23 Nov 2017 10:09:12 +0000 (11:09 +0100)
Also clean up related logic quite a bit and add unit tests.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
src/config.c

index 84038c836de9c86493e35b846f00a718c266d655..6ff03767f9e6a4bcb63a51307245bfbee0e58e27 100644 (file)
@@ -105,11 +105,59 @@ static inline bool parse_key(uint8_t key[static WG_KEY_LEN], const char *value)
 {
        if (!key_from_base64(key, value)) {
                fprintf(stderr, "Key is not the correct length or format: `%s'\n", value);
+               memset(key, 0, WG_KEY_LEN);
                return false;
        }
        return true;
 }
 
+static bool parse_keyfile(uint8_t key[static WG_KEY_LEN], const char *path)
+{
+       FILE *f;
+       int c;
+       char dst[WG_KEY_LEN_BASE64];
+       bool ret = false;
+
+       f = fopen(path, "r");
+       if (!f) {
+               perror("fopen");
+               return false;
+       }
+
+       if (fread(dst, WG_KEY_LEN_BASE64 - 1, 1, f) != 1) {
+               if (errno) {
+                       perror("fread");
+                       goto out;
+               }
+               /* If we're at the end and we didn't read anything, we're /dev/null or an empty file. */
+               if (!ferror(f) && feof(f) && !ftell(f)) {
+                       memset(key, 0, WG_KEY_LEN);
+                       ret = true;
+                       goto out;
+               }
+
+               fprintf(stderr, "Invalid length key in key file\n");
+               goto out;
+       }
+       dst[WG_KEY_LEN_BASE64 - 1] = '\0';
+
+       while ((c = getc(f)) != EOF) {
+               if (!isspace(c)) {
+                       fprintf(stderr, "Found trailing character in key file: `%c'\n", c);
+                       goto out;
+               }
+       }
+       if (ferror(f) && errno) {
+               perror("getc");
+               goto out;
+       }
+       ret = parse_key(key, dst);
+
+out:
+       fclose(f);
+       return ret;
+}
+
 static inline bool parse_ip(struct wgallowedip *allowedip, const char *value)
 {
        allowedip->family = AF_UNSPEC;
@@ -335,9 +383,7 @@ static bool process_line(struct config_ctx *ctx, const char *line)
                        ret = parse_fwmark(&ctx->device->fwmark, &ctx->device->flags, value);
                else if (key_match("PrivateKey")) {
                        ret = parse_key(ctx->device->private_key, value);
-                       if (!ret)
-                               memset(ctx->device->private_key, 0, WG_KEY_LEN);
-                       else
+                       if (ret)
                                ctx->device->flags |= WGDEVICE_HAS_PRIVATE_KEY;
                } else
                        goto error;
@@ -354,9 +400,7 @@ static bool process_line(struct config_ctx *ctx, const char *line)
                        ret = parse_persistent_keepalive(&ctx->last_peer->persistent_keepalive_interval, &ctx->last_peer->flags, value);
                else if (key_match("PresharedKey")) {
                        ret = parse_key(ctx->last_peer->preshared_key, value);
-                       if (!ret)
-                               memset(ctx->last_peer->preshared_key, 0, WG_KEY_LEN);
-                       else if (!key_is_zero(ctx->last_peer->preshared_key))
+                       if (ret)
                                ctx->last_peer->flags |= WGPEER_HAS_PRESHARED_KEY;
                } else
                        goto error;
@@ -429,54 +473,6 @@ err:
        return NULL;
 }
 
-static bool read_keyfile(char dst[WG_KEY_LEN_BASE64], const char *path)
-{
-       FILE *f;
-       int c;
-       bool ret = false;
-
-       f = fopen(path, "r");
-       if (!f) {
-               perror("fopen");
-               return false;
-       }
-
-       if (fread(dst, WG_KEY_LEN_BASE64 - 1, 1, f) != 1) {
-               if (errno) {
-                       perror("fread");
-                       goto out;
-               }
-               /* If we're at the end and we didn't read anything, we're /dev/null. */
-               if (!ferror(f) && feof(f) && !ftell(f)) {
-                       static const uint8_t zeros[WG_KEY_LEN] = { 0 };
-
-                       key_to_base64(dst, zeros);
-                       ret = true;
-                       goto out;
-               }
-
-               fprintf(stderr, "Invalid length key in key file\n");
-               goto out;
-       }
-       dst[WG_KEY_LEN_BASE64 - 1] = '\0';
-
-       while ((c = getc(f)) != EOF) {
-               if (!isspace(c)) {
-                       fprintf(stderr, "Found trailing character in key file: `%c'\n", c);
-                       goto out;
-               }
-       }
-       if (ferror(f) && errno) {
-               perror("getc");
-               goto out;
-       }
-       ret = true;
-
-out:
-       fclose(f);
-       return ret;
-}
-
 static char *strip_spaces(const char *in)
 {
        char *out;
@@ -517,14 +513,9 @@ struct wgdevice *config_read_cmd(char *argv[], int argc)
                        argv += 2;
                        argc -= 2;
                } else if (!strcmp(argv[0], "private-key") && argc >= 2 && !peer) {
-                       char key_line[WG_KEY_LEN_BASE64];
-
-                       if (read_keyfile(key_line, argv[1])) {
-                               if (!parse_key(device->private_key, key_line))
-                                       goto error;
-                               device->flags |= WGDEVICE_HAS_PRIVATE_KEY;
-                       } else
+                       if (!parse_keyfile(device->private_key, argv[1]))
                                goto error;
+                       device->flags |= WGDEVICE_HAS_PRIVATE_KEY;
                        argv += 2;
                        argc -= 2;
                } else if (!strcmp(argv[0], "peer") && argc >= 2) {
@@ -542,6 +533,7 @@ struct wgdevice *config_read_cmd(char *argv[], int argc)
                        peer = new_peer;
                        if (!parse_key(peer->public_key, argv[1]))
                                goto error;
+                       peer->flags |= WGPEER_HAS_PUBLIC_KEY;
                        argv += 2;
                        argc -= 2;
                } else if (!strcmp(argv[0], "remove") && argc >= 1 && peer) {
@@ -571,15 +563,9 @@ struct wgdevice *config_read_cmd(char *argv[], int argc)
                        argv += 2;
                        argc -= 2;
                } else if (!strcmp(argv[0], "preshared-key") && argc >= 2 && peer) {
-                       char key_line[WG_KEY_LEN_BASE64];
-
-                       if (read_keyfile(key_line, argv[1])) {
-                               if (!parse_key(peer->preshared_key, key_line))
-                                       goto error;
-                               if (!key_is_zero(peer->preshared_key))
-                                       peer->flags |= WGPEER_HAS_PRESHARED_KEY;
-                       } else
+                       if (!parse_keyfile(peer->preshared_key, argv[1]))
                                goto error;
+                       peer->flags |= WGPEER_HAS_PRESHARED_KEY;
                        argv += 2;
                        argc -= 2;
                } else {