]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
lib-oauth2: Ensure token algorithm matches with key
authorAki Tuomi <aki.tuomi@open-xchange.com>
Tue, 2 Jun 2020 12:52:34 +0000 (15:52 +0300)
committerAki Tuomi <aki.tuomi@open-xchange.com>
Fri, 5 Jun 2020 06:12:08 +0000 (09:12 +0300)
Otherwise we might mistakenly use key that is not intended
for the token.

src/lib-oauth2/oauth2-jwt.c
src/lib-oauth2/test-oauth2-jwt.c

index 7b0aaa0482fd17689daf60a24f5e7c12a71eb534..93fe81f5de960b751584733591d269da14fb86c3 100644 (file)
@@ -45,21 +45,23 @@ static int get_time_field(const struct json_tree *tree, const char *key,
 }
 
 static int oauth2_lookup_hmac_key(const struct oauth2_settings *set,
-                                 const char *key_id, const buffer_t **hmac_key_r,
+                                 const char *algo, const char *key_id,
+                                 const buffer_t **hmac_key_r,
                                  const char **error_r)
 {
        const char *base64_key;
-       if (oauth2_validation_key_cache_lookup_hmac_key(set->key_cache, key_id,
+       const char *cache_key_id = t_strconcat(key_id, ".", algo, NULL);
+       if (oauth2_validation_key_cache_lookup_hmac_key(set->key_cache, cache_key_id,
                                                        hmac_key_r) == 0)
                return 0;
        int ret;
-       const char *lookup_key = t_strconcat(DICT_PATH_SHARED, key_id, NULL);
+       const char *lookup_key = t_strconcat(DICT_PATH_SHARED, algo, "/", key_id, NULL);
        /* do a synchronous dict lookup */
        if ((ret = dict_lookup(set->key_dict, pool_datastack_create(),
                               lookup_key, &base64_key, error_r)) < 0) {
                return -1;
        } else if (ret == 0) {
-               *error_r = t_strdup_printf("Key '%s' not found", key_id);
+               *error_r = t_strdup_printf("%s key '%s' not found", algo, key_id);
                return -1;
        }
 
@@ -69,7 +71,7 @@ static int oauth2_lookup_hmac_key(const struct oauth2_settings *set,
                *error_r = "Invalid base64 encoded key";
                return -1;
        }
-       oauth2_validation_key_cache_insert_hmac_key(set->key_cache, key_id, key);
+       oauth2_validation_key_cache_insert_hmac_key(set->key_cache, cache_key_id, key);
        *hmac_key_r = key;
        return 0;
 }
@@ -91,7 +93,7 @@ static int oauth2_validate_hmac(const struct oauth2_settings *set,
        }
 
        const buffer_t *key;
-       if (oauth2_lookup_hmac_key(set, key_id, &key, error_r) < 0)
+       if (oauth2_lookup_hmac_key(set, algo, key_id, &key, error_r) < 0)
                return -1;
        struct hmac_context ctx;
        hmac_init(&ctx, key->data, key->used, method);
@@ -113,20 +115,22 @@ static int oauth2_validate_hmac(const struct oauth2_settings *set,
 }
 
 static int oauth2_lookup_pubkey(const struct oauth2_settings *set,
-                               const char *key_id, struct dcrypt_public_key **key_r,
+                               const char *algo, const char *key_id,
+                               struct dcrypt_public_key **key_r,
                                const char **error_r)
 {
        const char *key_str;
-       if (oauth2_validation_key_cache_lookup_pubkey(set->key_cache, key_id, key_r) == 0)
+       const char *cache_key_id = t_strconcat(key_id, ".", algo, NULL);
+       if (oauth2_validation_key_cache_lookup_pubkey(set->key_cache, cache_key_id, key_r) == 0)
                return 0;
        int ret;
-       const char *lookup_key = t_strconcat(DICT_PATH_SHARED, key_id, NULL);
+       const char *lookup_key = t_strconcat(DICT_PATH_SHARED, algo, "/", key_id, NULL);
        /* do a synchronous dict lookup */
        if ((ret = dict_lookup(set->key_dict, pool_datastack_create(),
                               lookup_key, &key_str, error_r)) < 0) {
                return -1;
        } else if (ret == 0) {
-               *error_r = t_strdup_printf("Key '%s' not found", key_id);
+               *error_r = t_strdup_printf("%s key '%s' not found", algo, key_id);
                return -1;
        }
 
@@ -139,7 +143,7 @@ static int oauth2_lookup_pubkey(const struct oauth2_settings *set,
        }
 
        /* cache key */
-       oauth2_validation_key_cache_insert_pubkey(set->key_cache, key_id, pubkey);
+       oauth2_validation_key_cache_insert_pubkey(set->key_cache, cache_key_id, pubkey);
        *key_r = pubkey;
        return 0;
 }
@@ -185,7 +189,7 @@ static int oauth2_validate_rsa_ecdsa(const struct oauth2_settings *set,
                t_base64url_decode_str(BASE64_DECODE_FLAG_NO_PADDING, blobs[2]);
 
        struct dcrypt_public_key *pubkey;
-       if (oauth2_lookup_pubkey(set, key_id, &pubkey, error_r) < 0)
+       if (oauth2_lookup_pubkey(set, algo, key_id, &pubkey, error_r) < 0)
                return -1;
 
        /* data to verify */
index 31698b3cdc7463e63c52252e9ff70cefc2d7f15f..1dc62cde88ca16498d1c4d25cc6b7e8701aa9f18 100644 (file)
@@ -205,12 +205,12 @@ static buffer_t *create_jwt_token_fields(const char *algo, time_t exp, time_t ia
        return tokenbuf;
 }
 
-#define save_key(key) save_key_to("default", (key))
-static void save_key_to(const char *name, const char *keydata)
+#define save_key(algo, key) save_key_to(algo, "default", (key))
+static void save_key_to(const char *algo, const char *name, const char *keydata)
 {
        const char *error;
        struct dict_transaction_context *ctx = dict_transaction_begin(keys_dict);
-       dict_set(ctx, t_strconcat(DICT_PATH_SHARED, name, NULL), keydata);
+       dict_set(ctx, t_strconcat(DICT_PATH_SHARED, algo, "/", name, NULL), keydata);
        if (dict_transaction_commit(&ctx, &error) < 0)
                i_error("dict_set(%s) failed: %s", name, error);
 }
@@ -468,12 +468,12 @@ static void test_jwt_key_files(void)
        void *ptr = buffer_append_space_unsafe(secret, 32);
        random_fill(ptr, 32);
        buffer_t *b64_key = t_base64_encode(0, (size_t)-1, secret->data, secret->used);
-       save_key_to("first", str_c(b64_key));
+       save_key_to("HS256", "first", str_c(b64_key));
        buffer_t *secret2 = t_buffer_create(32);
        ptr = buffer_append_space_unsafe(secret2, 32);
        random_fill(ptr, 32);
        b64_key = t_base64_encode(0, (size_t)-1, secret2->data, secret2->used);
-       save_key_to("second", str_c(b64_key));
+       save_key_to("HS256", "second", str_c(b64_key));
 
        /* create and sign token */
        buffer_t *token_1 = create_jwt_token_kid("HS256", "first");
@@ -491,7 +491,7 @@ static void test_jwt_key_files(void)
 
        test_assert(parse_jwt_token(&req, str_c(token_3), &is_jwt, &error) != 0);
        test_assert(is_jwt == TRUE);
-       test_assert_strcmp(error, "Key 'missing' not found");
+       test_assert_strcmp(error, "HS256 key 'missing' not found");
        test_assert(parse_jwt_token(&req, str_c(token_4), &is_jwt, &error) != 0);
        test_assert(is_jwt == TRUE);
        test_assert_strcmp(error, "'kid' field is empty");
@@ -508,7 +508,7 @@ static void test_jwt_rs_token(void)
        test_begin("JWT RSA token");
        /* write public key to file */
        oauth2_validation_key_cache_evict(key_cache, "default");
-       save_key(rsa_public_key);
+       save_key("RS256", rsa_public_key);
 
        buffer_t *tokenbuf = create_jwt_token("RS256");
        /* sign token */
@@ -541,7 +541,7 @@ static void test_jwt_ps_token(void)
        test_begin("JWT RSAPSS token");
        /* write public key to file */
        oauth2_validation_key_cache_evict(key_cache, "default");
-       save_key(rsa_public_key);
+       save_key("PS256", rsa_public_key);
 
        buffer_t *tokenbuf = create_jwt_token("PS256");
        /* sign token */
@@ -586,7 +586,7 @@ static void test_jwt_ec_token(void)
                exit(1);
        }
        oauth2_validation_key_cache_evict(key_cache, "default");
-       save_key(str_c(keybuf));
+       save_key("ES256", str_c(keybuf));
 
        buffer_t *tokenbuf = create_jwt_token("ES256");
        /* sign token */
@@ -634,7 +634,7 @@ static void test_do_init(void)
        random_fill(ptr, 32);
        buffer_t *b64_key = t_base64_encode(0, (size_t)-1,
                                            hs_sign_key->data, hs_sign_key->used);
-       save_key(str_c(b64_key));
+       save_key("HS256", str_c(b64_key));
 }
 
 static void test_do_deinit(void)