]> git.ipfire.org Git - thirdparty/wireguard-tools.git/commitdiff
wg: improve error reporting and detection
authorJason A. Donenfeld <Jason@zx2c4.com>
Sun, 3 Jul 2016 18:06:33 +0000 (20:06 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Sun, 3 Jul 2016 18:45:48 +0000 (20:45 +0200)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
src/config.c
src/genkey.c
src/pubkey.c
src/wg.c

index 0cec30e4e393170d74a615e6715fdc27c30816d4..9066178979edfbf7b5d853114635d8afab9baa6d 100644 (file)
@@ -93,12 +93,8 @@ static inline uint16_t parse_port(const char *value)
 static inline bool parse_key(uint8_t key[WG_KEY_LEN], const char *value)
 {
        uint8_t tmp[WG_KEY_LEN + 1];
-       if (strlen(value) != b64_len(WG_KEY_LEN) - 1) {
-               fprintf(stderr, "Key is not the correct length: `%s`\n", value);
-               return false;
-       }
-       if (b64_pton(value, tmp, WG_KEY_LEN + 1) < 0) {
-               fprintf(stderr, "Could not parse base64 key: `%s`\n", value);
+       if (strlen(value) != b64_len(WG_KEY_LEN) - 1 || b64_pton(value, tmp, WG_KEY_LEN + 1) != WG_KEY_LEN) {
+               fprintf(stderr, "Key is not the correct length or format: `%s`\n", value);
                return false;
        }
        memcpy(key, tmp, WG_KEY_LEN);
index 1602ae1893a8317dc08b4c3a44290e419891d549..8e6310861248fcc8f79af8afb9d557e9e1e04425 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "curve25519.h"
 #include "base64.h"
+#include "subcommands.h"
 
 #ifdef __NR_getrandom
 static inline ssize_t get_random_bytes(uint8_t *out, size_t len)
@@ -37,6 +38,11 @@ int genkey_main(int argc, char *argv[])
        char private_key_base64[b64_len(CURVE25519_POINT_SIZE)];
        struct stat stat;
 
+       if (argc != 1) {
+               fprintf(stderr, "Usage: %s %s\n", PROG_NAME, argv[0]);
+               return 1;
+       }
+
        if (!fstat(STDOUT_FILENO, &stat) && S_ISREG(stat.st_mode) && stat.st_mode & S_IRWXO)
                fputs("Warning: writing to world accessible file.\nConsider setting the umask to 077 and trying again.\n", stderr);
 
@@ -47,9 +53,8 @@ int genkey_main(int argc, char *argv[])
        if (argc && !strcmp(argv[0], "genkey"))
                curve25519_normalize_secret(private_key);
 
-       if (b64_ntop(private_key, sizeof(private_key), private_key_base64, sizeof(private_key_base64)) < 0) {
-               errno = EINVAL;
-               perror("b64");
+       if (b64_ntop(private_key, sizeof(private_key), private_key_base64, sizeof(private_key_base64)) != sizeof(private_key_base64) - 1) {
+               fprintf(stderr, "%s: Could not convert key to base64\n", PROG_NAME);
                return 1;
        }
 
index d9a97d93b363db600b3787aca66f13ddf2d3c0dd..452c8fa2f77c32f9a03530b409fd2950a58ac4f7 100644 (file)
@@ -3,29 +3,46 @@
 #include <errno.h>
 #include <resolv.h>
 #include <stdio.h>
+#include <ctype.h>
 
 #include "curve25519.h"
 #include "base64.h"
+#include "subcommands.h"
 
-int pubkey_main(__attribute__((unused)) int argc, __attribute__((unused)) char *argv[])
+int pubkey_main(int argc, char *argv[])
 {
        unsigned char private_key[CURVE25519_POINT_SIZE + 1] = { 0 }, public_key[CURVE25519_POINT_SIZE] = { 0 };
        char private_key_base64[b64_len(CURVE25519_POINT_SIZE)] = { 0 }, public_key_base64[b64_len(CURVE25519_POINT_SIZE)] = { 0 };
+       int trailing_char;
+
+       if (argc != 1) {
+               fprintf(stderr, "Usage: %s %s\n", PROG_NAME, argv[0]);
+               return 1;
+       }
 
        if (fread(private_key_base64, 1, sizeof(private_key_base64) - 1, stdin) != sizeof(private_key_base64) - 1) {
                errno = EINVAL;
-               perror("fread(private key)");
+               fprintf(stderr, "%s: Key is not the correct length or format\n", PROG_NAME);
                return 1;
        }
-       if (b64_pton(private_key_base64, private_key, sizeof(private_key)) < 0) {
-               errno = EINVAL;
-               perror("b64");
+
+       for (;;) {
+               trailing_char = getc(stdin);
+               if (!trailing_char || isspace(trailing_char) || isblank(trailing_char))
+                       continue;
+               if (trailing_char == EOF)
+                       break;
+               fprintf(stderr, "%s: Trailing characters found after key\n", PROG_NAME);
+               return 1;
+       }
+
+       if (b64_pton(private_key_base64, private_key, sizeof(private_key)) != sizeof(private_key) - 1) {
+               fprintf(stderr, "%s: Key is not the correct length or format\n", PROG_NAME);
                return 1;
        }
        curve25519_generate_public(public_key, private_key);
-       if (b64_ntop(public_key, sizeof(public_key), public_key_base64, sizeof(public_key_base64)) < 0) {
-               errno = EINVAL;
-               perror("b64");
+       if (b64_ntop(public_key, sizeof(public_key), public_key_base64, sizeof(public_key_base64)) != sizeof(public_key_base64) - 1) {
+               fprintf(stderr, "%s: Could not convert key to base64\n", PROG_NAME);
                return 1;
        }
        puts(public_key_base64);
index d4d2965967183e761ad62f1015a37bddae4ab8a8..ee19387c5579fdf3e94bb295931252b9fec1dd9c 100644 (file)
--- a/src/wg.c
+++ b/src/wg.c
@@ -23,12 +23,13 @@ static const struct {
        { "pubkey", pubkey_main, "Reads a private key from stdin and writes a public key to stdout" }
 };
 
-static void show_usage(void)
+static void show_usage(FILE *file)
 {
-       fprintf(stderr, "Usage: %s <cmd> [<args>]\n\n", PROG_NAME);
-       fprintf(stderr, "Available subcommands:\n");
+       fprintf(file, "Usage: %s <cmd> [<args>]\n\n", PROG_NAME);
+       fprintf(file, "Available subcommands:\n");
        for (size_t i = 0; i < sizeof(subcommands) / sizeof(subcommands[0]); ++i)
-               fprintf(stderr, "  %s: %s\n", subcommands[i].subcommand, subcommands[i].description);
+               fprintf(file, "  %s: %s\n", subcommands[i].subcommand, subcommands[i].description);
+       fprintf(file, "You may pass `--help' to any of these subcommands to view usage.\n");
 }
 
 int main(int argc, char *argv[])
@@ -37,8 +38,8 @@ int main(int argc, char *argv[])
        PROG_NAME = argv[0];
 
        if (argc == 2 && (!strcmp(argv[1], "-h") || !strcmp(argv[1], "--help") || !strcmp(argv[1], "help"))) {
-               show_usage();
-               return 1;
+               show_usage(stdout);
+               return 0;
        }
 
        if (argc == 1) {
@@ -61,6 +62,6 @@ findsubcommand:
        }
 
        fprintf(stderr, "Invalid subcommand: `%s`\n", argv[1]);
-       show_usage();
+       show_usage(stderr);
        return 1;
 }