]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
lib-oauth2: Add support for validating JWT tokens
authorAki Tuomi <aki.tuomi@open-xchange.com>
Sat, 8 Feb 2020 17:11:43 +0000 (19:11 +0200)
committeraki.tuomi <aki.tuomi@open-xchange.com>
Thu, 20 Feb 2020 11:19:56 +0000 (11:19 +0000)
This adds support for handling JWT tokens without external server.
It supports HS/RS/PS/ES algorithms with SHA-2 hashes.

The validation keys are pulled from specified dict and cached automatically
in memory.

src/lib-oauth2/Makefile.am
src/lib-oauth2/oauth2-jwt.c [new file with mode: 0644]
src/lib-oauth2/oauth2-key-cache.c [new file with mode: 0644]
src/lib-oauth2/oauth2-private.h
src/lib-oauth2/oauth2.h
src/lib-oauth2/test-oauth2-jwt.c [new file with mode: 0644]

index 70e3bb430903d1e3326f315b7a2b35eb32850c2a..aa4d9d98875154ff38576a3a24558c92a8efc1b7 100644 (file)
@@ -2,6 +2,8 @@ AM_CPPFLAGS = \
        -I$(top_srcdir)/src/lib \
        -I$(top_srcdir)/src/lib-test \
        -I$(top_srcdir)/src/lib-http \
+       -I$(top_srcdir)/src/lib-dcrypt \
+       -I$(top_srcdir)/src/lib-dict \
        -I$(top_srcdir)/src/lib-settings
 
 noinst_LTLIBRARIES=liboauth2.la
@@ -15,29 +17,36 @@ noinst_HEADERS = \
 
 liboauth2_la_SOURCES = \
        oauth2.c \
-       oauth2-request.c
+       oauth2-request.c \
+       oauth2-jwt.c \
+       oauth2-key-cache.c
 
 test_programs = \
-       test-oauth2-json
+       test-oauth2-json \
+       test-oauth2-jwt
 
 noinst_PROGRAMS = $(test_programs)
 
 test_libs = \
        $(noinst_LTLIBRARIES) \
+       ../lib-dcrypt/libdcrypt.la \
        ../lib-http/libhttp.la \
        ../lib-dns/libdns.la \
        ../lib-ssl-iostream/libssl_iostream.la \
        ../lib-master/libmaster.la \
+       ../lib-dict/libdict.la \
        ../lib-settings/libsettings.la \
        ../lib-test/libtest.la \
        ../lib/liblib.la \
        $(MODULE_LIBS)
 test_deps = \
        $(noinst_LTLIBRARIES) \
+       ../lib-dcrypt/libdcrypt.la \
        ../lib-http/libhttp.la \
        ../lib-dns/libdns.la \
        ../lib-ssl-iostream/libssl_iostream.la \
        ../lib-master/libmaster.la \
+       ../lib-dict/libdict.la \
        ../lib-settings/libsettings.la \
        ../lib-test/libtest.la \
        ../lib/liblib.la
@@ -46,6 +55,10 @@ test_oauth2_json_SOURCES = test-oauth2-json.c
 test_oauth2_json_LDADD = $(test_libs)
 test_oauth2_json_DEPENDENCIES = $(test_deps)
 
+test_oauth2_jwt_SOURCES = test-oauth2-jwt.c
+test_oauth2_jwt_LDADD = $(test_libs)
+test_oauth2_jwt_DEPENDENCIES = $(test_deps)
+
 check-local:
        for bin in $(test_programs); do \
          if ! $(RUN_TEST) ./$$bin; then exit 1; fi; \
diff --git a/src/lib-oauth2/oauth2-jwt.c b/src/lib-oauth2/oauth2-jwt.c
new file mode 100644 (file)
index 0000000..527a0f3
--- /dev/null
@@ -0,0 +1,392 @@
+/* Copyright (c) 2020 Dovecot authors, see the included COPYING file */
+
+#include "lib.h"
+#include "buffer.h"
+#include "str.h"
+#include "hmac.h"
+#include "array.h"
+#include "hash-method.h"
+#include "istream.h"
+#include "iso8601-date.h"
+#include "json-tree.h"
+#include "array.h"
+#include "base64.h"
+#include "dcrypt.h"
+#include "var-expand.h"
+#include "oauth2.h"
+#include "oauth2-private.h"
+#include "dict.h"
+
+#include <time.h>
+
+static const char *get_field(const struct json_tree *tree, const char *key)
+{
+       const struct json_tree_node *root = json_tree_root(tree);
+       const struct json_tree_node *value_node = json_tree_find_key(root, key);
+       if (value_node == NULL || value_node->value_type == JSON_TYPE_OBJECT ||
+           value_node->value_type == JSON_TYPE_ARRAY)
+               return NULL;
+       return json_tree_get_value_str(value_node);
+}
+
+static int get_time_field(const struct json_tree *tree, const char *key,
+                         long *value_r)
+{
+       const char *value = get_field(tree, key);
+       int tz_offset ATTR_UNUSED;
+       if (value == NULL)
+               return 0;
+       if ((str_to_long(value, value_r) < 0 &&
+            !iso8601_date_parse((const unsigned char*)value, strlen(value),
+                                value_r, &tz_offset)) ||
+           *value_r < 0)
+                return -1;
+       return 1;
+}
+
+static int oauth2_lookup_hmac_key(const struct oauth2_settings *set,
+                                 const char *key_id, const buffer_t **hmac_key_r,
+                                 const char **error_r)
+{
+       const char *base64_key;
+       if (oauth2_validation_key_cache_lookup_hmac_key(set->key_cache, key_id,
+                                                       hmac_key_r) == 0)
+               return 0;
+       int ret;
+       const char *lookup_key = t_strconcat(DICT_PATH_SHARED, key_id, NULL);
+       /* do a synchronous dict lookup */
+       if ((ret = dict_lookup(set->key_dict, pool_datastack_create(),
+                              lookup_key, &base64_key, error_r)) < 0) {
+               return -1;
+       } else if (ret == 0) {
+               *error_r = t_strdup_printf("Key '%s' not found", key_id);
+               return -1;
+       }
+
+       /* decode key */
+       buffer_t *key = t_base64_decode_str(base64_key);
+       if (key->used == 0) {
+               *error_r = "Invalid base64 encoded key";
+               return -1;
+       }
+       oauth2_validation_key_cache_insert_hmac_key(set->key_cache, key_id, key);
+       *hmac_key_r = key;
+       return 0;
+}
+
+static int oauth2_validate_hmac(const struct oauth2_settings *set,
+                               const char *algo, const char *key_id,
+                               const char *const *blobs, const char **error_r)
+{
+       const struct hash_method *method;
+       if (strcmp(algo, "HS256") == 0)
+               method = hash_method_lookup("sha256");
+       else if (strcmp(algo, "HS384") == 0)
+               method = hash_method_lookup("sha384");
+       else if (strcmp(algo, "HS512") == 0)
+               method = hash_method_lookup("sha512");
+       else {
+               *error_r = t_strdup_printf("unsupported algorithm '%s'", algo);
+               return -1;
+       }
+
+       const buffer_t *key;
+       if (oauth2_lookup_hmac_key(set, key_id, &key, error_r) < 0)
+               return -1;
+       struct hmac_context ctx;
+       hmac_init(&ctx, key->data, key->used, method);
+       hmac_update(&ctx, blobs[0], strlen(blobs[0]));
+       hmac_update(&ctx, ".", 1);
+       hmac_update(&ctx, blobs[1], strlen(blobs[1]));
+       unsigned char digest[method->digest_size];
+
+       hmac_final(&ctx, digest);
+
+       buffer_t *their_digest =
+               t_base64url_decode_str(BASE64_DECODE_FLAG_NO_PADDING, blobs[2]);
+       if (method->digest_size != their_digest->used ||
+           memcmp(digest, their_digest->data, method->digest_size) != 0) {
+               *error_r = "Incorrect JWT signature";
+               return -1;
+       }
+       return 0;
+}
+
+static int oauth2_lookup_pubkey(const struct oauth2_settings *set,
+                               const char *key_id, struct dcrypt_public_key **key_r,
+                               const char **error_r)
+{
+       const char *key_str;
+       if (oauth2_validation_key_cache_lookup_pubkey(set->key_cache, key_id, key_r) == 0)
+               return 0;
+       int ret;
+       const char *lookup_key = t_strconcat(DICT_PATH_SHARED, key_id, NULL);
+       /* do a synchronous dict lookup */
+       if ((ret = dict_lookup(set->key_dict, pool_datastack_create(),
+                              lookup_key, &key_str, error_r)) < 0) {
+               return -1;
+       } else if (ret == 0) {
+               *error_r = t_strdup_printf("Key '%s' not found", key_id);
+               return -1;
+       }
+
+       /* try to load key */
+       struct dcrypt_public_key *pubkey;
+       const char *error;
+       if (!dcrypt_key_load_public(&pubkey, key_str, &error)) {
+               *error_r = t_strdup_printf("Cannot load key: %s", error);
+               return -1;
+       }
+
+       /* cache key */
+       oauth2_validation_key_cache_insert_pubkey(set->key_cache, key_id, pubkey);
+       *key_r = pubkey;
+       return 0;
+}
+
+static int oauth2_validate_rsa_ecdsa(const struct oauth2_settings *set,
+                                    const char *algo, const char *key_id,
+                                    const char *const *blobs, const char **error_r)
+{
+       const char *method;
+       enum dcrypt_padding padding;
+       enum dcrypt_signature_format sig_format;
+       if (!dcrypt_is_initialized()) {
+               *error_r = "No crypto library loaded";
+               return -1;
+       }
+
+       if (str_begins(algo, "RS")) {
+               padding = DCRYPT_PADDING_RSA_PKCS1;
+               sig_format = DCRYPT_SIGNATURE_FORMAT_DSS;
+       } else if (str_begins(algo, "PS")) {
+               padding = DCRYPT_PADDING_RSA_PKCS1_PSS;
+               sig_format = DCRYPT_SIGNATURE_FORMAT_DSS;
+       } else if (str_begins(algo, "ES")) {
+               padding = DCRYPT_PADDING_DEFAULT;
+               sig_format = DCRYPT_SIGNATURE_FORMAT_X962;
+       } else {
+               /* this should be checked by caller */
+               i_unreached();
+       }
+
+       if (strcmp(algo+2, "256") == 0) {
+               method = "sha256";
+       } else if (strcmp(algo+2, "384") == 0) {
+               method = "sha384";
+       } else if (strcmp(algo+2, "512") == 0) {
+               method = "sha512";
+       } else {
+               *error_r = t_strdup_printf("Unsupported algorithm '%s'", algo);
+               return -1;
+       }
+
+       buffer_t *signature =
+               t_base64url_decode_str(BASE64_DECODE_FLAG_NO_PADDING, blobs[2]);
+
+       struct dcrypt_public_key *pubkey;
+       if (oauth2_lookup_pubkey(set, key_id, &pubkey, error_r) < 0)
+               return -1;
+
+       /* data to verify */
+       const char *data = t_strconcat(blobs[0], ".", blobs[1], NULL);
+
+       /* verify signature */
+       bool valid;
+       if (!dcrypt_verify(pubkey, method, sig_format, data, strlen(data),
+                          signature->data, signature->used, &valid, padding, error_r)) {
+               valid = FALSE;
+       } else if (!valid) {
+               *error_r = "Bad signature";
+       }
+
+       return valid ? 0 : -1;
+}
+
+static int oauth2_validate_signature(const struct oauth2_settings *set,
+                                    const char *algo, const char *key_id,
+                                    const char *const *blobs, const char **error_r)
+{
+       if (str_begins(algo, "HS"))
+               return oauth2_validate_hmac(set, algo, key_id, blobs, error_r);
+       else if (str_begins(algo, "RS") || str_begins(algo, "PS") ||
+                str_begins(algo, "ES"))
+               return oauth2_validate_rsa_ecdsa(set, algo, key_id, blobs, error_r);
+
+       *error_r = t_strdup_printf("Unsupported algorithm '%s'", algo);
+       return -1;
+}
+
+static void
+oauth2_jwt_copy_fields(ARRAY_TYPE(oauth2_field) *fields, struct json_tree *tree)
+{
+       pool_t pool = array_get_pool(fields);
+       ARRAY(const struct json_tree_node*) nodes;
+       t_array_init(&nodes, 1);
+       const struct json_tree_node *root = json_tree_root(tree);
+       array_push_back(&nodes, &root);
+
+       while (array_count(&nodes) > 0) {
+               const struct json_tree_node *const *pnode = array_front(&nodes);
+               const struct json_tree_node *node = *pnode;
+               array_pop_front(&nodes);
+               while (node != NULL) {
+                       if (node->value_type == JSON_TYPE_OBJECT) {
+                               root = node->value.child;
+                               array_push_back(&nodes, &root);
+                       } else if (node->key != NULL) {
+                               struct oauth2_field *field =
+                                       array_append_space(fields);
+                               field->name = p_strdup(pool, node->key);
+                               field->value = p_strdup(pool, json_tree_get_value_str(node));
+                       }
+                       node = node->next;
+               }
+       }
+}
+
+static int
+oauth2_jwt_header_process(struct json_tree *tree, const char **alg_r,
+                         const char **kid_r, const char **error_r)
+{
+       const char *typ = get_field(tree, "typ");
+       const char *algo = get_field(tree, "alg");
+       const char *kid = get_field(tree, "kid");
+
+       if (null_strcmp(typ, "JWT") != 0) {
+               *error_r = "Cannot find 'typ' field";
+               return -1;
+       }
+
+       if (algo == NULL) {
+               *error_r = "Cannot find 'alg' field";
+               return -1;
+       }
+
+       /* these are lost when tree is deinit */
+       *alg_r = t_strdup(algo);
+       *kid_r = t_strdup(kid);
+       return 0;
+}
+
+static int
+oauth2_jwt_body_process(ARRAY_TYPE(oauth2_field) *fields, struct json_tree *tree,
+                       const char **error_r)
+{
+       const char *sub = get_field(tree, "sub");
+
+       int ret;
+       long t0 = time(NULL);
+       /* default IAT and NBF to now */
+       long iat, nbf, exp;
+       int tz_offset ATTR_UNUSED;
+
+       if (sub == NULL) {
+               *error_r = "Missing 'sub' field";
+               return -1;
+       }
+
+       if ((ret = get_time_field(tree, "exp", &exp)) < 1) {
+               *error_r = t_strdup_printf("%s 'exp' field",
+                               ret == 0 ? "Missing" : "Malformed");
+               return -1;
+       }
+
+       if ((ret = get_time_field(tree, "nbf", &nbf)) < 0) {
+               *error_r = "Malformed 'nbf' field";
+               return -1;
+       } else if (ret == 0)
+               nbf = t0;
+
+       if ((ret = get_time_field(tree, "iat", &iat)) < 0) {
+               *error_r = "Malformed 'iat' field";
+               return -1;
+       } else if (ret == 0)
+               iat = t0;
+
+       if (nbf > t0) {
+               *error_r = "Token is not valid yet";
+               return -1;
+       }
+       if (iat > t0) {
+               *error_r = "Token is issued in future";
+               return -1;
+       }
+       if (exp < t0) {
+               *error_r = "Token has expired";
+               return -1;
+       }
+
+       /* ensure token dates are not conflicting */
+       if (nbf < iat ||
+           exp < iat ||
+           exp < nbf) {
+               *error_r = "Token time values are conflicting";
+               return -1;
+       }
+
+       oauth2_jwt_copy_fields(fields, tree);
+       return 0;
+}
+
+int oauth2_try_parse_jwt(const struct oauth2_settings *set,
+                        const char *token, ARRAY_TYPE(oauth2_field) *fields,
+                        bool *is_jwt_r, const char **error_r)
+{
+       const char *const *blobs = t_strsplit(token, ".");
+       int ret;
+
+       i_assert(set->key_dict != NULL);
+
+       /* we don't know if it's JWT token yet */
+       *is_jwt_r = FALSE;
+
+       if (str_array_length(blobs) != 3) {
+               *error_r = "Not a JWT token";
+               return -1;
+       }
+
+       /* attempt to decode header */
+       buffer_t *header =
+               t_base64url_decode_str(BASE64_DECODE_FLAG_NO_PADDING, blobs[0]);
+
+       if (header->used == 0) {
+               *error_r = "Not a JWT token";
+               return -1;
+       }
+
+       struct json_tree *header_tree;
+       if (oauth2_json_tree_build(header, &header_tree, error_r) < 0)
+               return -1;
+
+       const char *algo, *kid;
+       ret = oauth2_jwt_header_process(header_tree, &algo, &kid, error_r);
+       json_tree_deinit(&header_tree);
+       if (ret < 0)
+               return -1;
+
+       /* it is now assumed to be a JWT token */
+       *is_jwt_r = TRUE;
+
+       if (kid == NULL)
+               kid = "default";
+       else if (*kid == '\0') {
+               *error_r = "'kid' field is empty";
+               return -1;
+       }
+
+       /* from now on, this is considered a JWT token. try to validate signature. */
+       if (oauth2_validate_signature(set, algo, kid, blobs, error_r) < 0)
+               return -1;
+
+       /* then parse the actual body */
+       struct json_tree *body_tree;
+       buffer_t *body =
+               t_base64url_decode_str(BASE64_DECODE_FLAG_NO_PADDING, blobs[1]);
+       if (oauth2_json_tree_build(body, &body_tree, error_r) == -1)
+               return -1;
+       ret = oauth2_jwt_body_process(fields, body_tree, error_r);
+       json_tree_deinit(&body_tree);
+
+       return ret;
+}
diff --git a/src/lib-oauth2/oauth2-key-cache.c b/src/lib-oauth2/oauth2-key-cache.c
new file mode 100644 (file)
index 0000000..1a903ed
--- /dev/null
@@ -0,0 +1,139 @@
+/* Copyright (c) 2020 Dovecot authors, see the included COPYING file */
+
+#include "lib.h"
+#include "array.h"
+#include "llist.h"
+#include "buffer.h"
+#include "hash2.h"
+#include "dcrypt.h"
+#include "oauth2.h"
+#include "oauth2-private.h"
+
+struct oauth2_key_cache_entry {
+       const char *key_id;
+       struct dcrypt_public_key *pubkey;
+       buffer_t *hmac_key;
+       struct oauth2_key_cache_entry *prev, *next;
+};
+
+struct oauth2_validation_key_cache {
+       pool_t pool;
+       struct hash2_table *keys;
+       struct oauth2_key_cache_entry *list_start;
+};
+
+struct oauth2_validation_key_cache *oauth2_validation_key_cache_init(void)
+{
+       pool_t pool = pool_alloconly_create(MEMPOOL_GROWING"oauth2 key cache", 128);
+       struct oauth2_validation_key_cache *cache =
+               p_new(pool, struct oauth2_validation_key_cache, 1);
+       cache->pool = pool;
+       cache->keys = hash2_create(8, sizeof(struct oauth2_key_cache_entry),
+                                  hash2_str_hash, hash2_strcmp, NULL);
+       return cache;
+}
+
+void oauth2_validation_key_cache_deinit(struct oauth2_validation_key_cache **_cache)
+{
+       struct oauth2_validation_key_cache *cache = *_cache;
+       *_cache = NULL;
+       if (cache == NULL)
+               return;
+
+       /* free resources */
+       struct oauth2_key_cache_entry *entry = cache->list_start;
+       while (entry != NULL) {
+               if (entry->pubkey != NULL)
+                       dcrypt_key_unref_public(&entry->pubkey);
+               entry = entry->next;
+       }
+       hash2_destroy(&cache->keys);
+       pool_unref(&cache->pool);
+}
+
+int oauth2_validation_key_cache_lookup_pubkey(struct oauth2_validation_key_cache *cache,
+                                             const char *key_id,
+                                             struct dcrypt_public_key **pubkey_r)
+{
+       if (cache == NULL)
+               return -1;
+       struct oauth2_key_cache_entry *entry = hash2_lookup(cache->keys, key_id);
+       if (entry == NULL || entry->pubkey == NULL)
+               return -1;
+
+       *pubkey_r = entry->pubkey;
+       return 0;
+}
+
+int oauth2_validation_key_cache_lookup_hmac_key(struct oauth2_validation_key_cache *cache,
+                                               const char *key_id,
+                                               const buffer_t **hmac_key_r)
+{
+       if (cache == NULL)
+               return -1;
+       struct oauth2_key_cache_entry *entry = hash2_lookup(cache->keys, key_id);
+       if (entry == NULL || entry->hmac_key == NULL ||
+           entry->hmac_key->used == 0)
+               return -1;
+
+       *hmac_key_r = entry->hmac_key;
+       return 0;
+}
+
+void oauth2_validation_key_cache_insert_pubkey(struct oauth2_validation_key_cache *cache,
+                                              const char *key_id,
+                                              struct dcrypt_public_key *pubkey)
+{
+       if (cache == NULL)
+               return;
+       struct oauth2_key_cache_entry *entry = hash2_lookup(cache->keys, key_id);
+       if (entry != NULL) {
+               dcrypt_key_unref_public(&entry->pubkey);
+               entry->pubkey = pubkey;
+               if (entry->hmac_key != NULL)
+                       buffer_set_used_size(entry->hmac_key, 0);
+               return;
+       }
+       entry = hash2_insert(cache->keys, key_id);
+       entry->key_id = p_strdup(cache->pool, key_id);
+       entry->pubkey = pubkey;
+       DLLIST_PREPEND(&cache->list_start, entry);
+}
+
+void oauth2_validation_key_cache_insert_hmac_key(struct oauth2_validation_key_cache *cache,
+                                                const char *key_id,
+                                                const buffer_t *hmac_key)
+{
+       if (cache == NULL)
+               return;
+       struct oauth2_key_cache_entry *entry = hash2_lookup(cache->keys, key_id);
+       if (entry != NULL) {
+               dcrypt_key_unref_public(&entry->pubkey);
+               if (entry->hmac_key == NULL)
+                       entry->hmac_key = buffer_create_dynamic(cache->pool, hmac_key->used);
+               else
+                       buffer_set_used_size(entry->hmac_key, 0);
+               buffer_append(entry->hmac_key, hmac_key->data, hmac_key->used);
+               return;
+       }
+       entry = hash2_insert(cache->keys, key_id);
+       entry->key_id = p_strdup(cache->pool, key_id);
+       entry->hmac_key = buffer_create_dynamic(cache->pool, hmac_key->used);
+       buffer_append(entry->hmac_key, hmac_key->data, hmac_key->used);
+       DLLIST_PREPEND(&cache->list_start, entry);
+}
+
+int oauth2_validation_key_cache_evict(struct oauth2_validation_key_cache *cache,
+                                     const char *key_id)
+{
+       if (cache == NULL)
+               return -1;
+       struct oauth2_key_cache_entry *entry = hash2_lookup(cache->keys, key_id);
+       if (entry == NULL)
+               return -1;
+       if (entry->pubkey != NULL)
+               dcrypt_key_unref_public(&entry->pubkey);
+       DLLIST_REMOVE(&cache->list_start, entry);
+       hash2_remove(cache->keys, key_id);
+       return 0;
+}
index 9e0b0b9e61f42f3d6704a754706c356c4edaf3e7..83ae5ee56e043a96a579b5f9aef80bc4e7c3c171 100644 (file)
@@ -3,6 +3,7 @@
 #define OAUTH2_PRIVATE_H 1
 
 struct json_tree;
+struct dcrypt_public_key;
 
 struct oauth2_request {
        pool_t pool;
@@ -17,6 +18,7 @@ struct oauth2_request {
        struct timeout *to_delayed_error;
 
        const char *username;
+       const char *key_file_template;
 
        void (*json_parsed_cb)(struct oauth2_request*, bool success,
                               const char *error);
@@ -39,4 +41,16 @@ void oauth2_parse_json(struct oauth2_request *req);
 int oauth2_json_tree_build(const buffer_t *json, struct json_tree **tree_r,
                           const char **error_r);
 
+int oauth2_validation_key_cache_lookup_pubkey(struct oauth2_validation_key_cache *cache,
+                                             const char *key_id,
+                                             struct dcrypt_public_key **pubkey_r);
+int oauth2_validation_key_cache_lookup_hmac_key(struct oauth2_validation_key_cache *cache,
+                                               const char *key_id,
+                                               const buffer_t **hmac_key_r);
+void oauth2_validation_key_cache_insert_pubkey(struct oauth2_validation_key_cache *cache,
+                                              const char *key_id,
+                                              struct dcrypt_public_key *pubkey);
+void oauth2_validation_key_cache_insert_hmac_key(struct oauth2_validation_key_cache *cache,
+                                                const char *key_id,
+                                                const buffer_t *hmac_key);
 #endif
index aaacfeb20929525fe8aa8bc0954cc9a0cd91f791..d65152828492e84cbf88c29963c4b9da33028a92 100644 (file)
@@ -4,7 +4,9 @@
 
 #include "net.h"
 
+struct dict;
 struct oauth2_request;
+struct oauth2_validation_key_cache;
 
 struct oauth2_field {
        const char *name;
@@ -30,6 +32,11 @@ struct oauth2_settings {
        const char *client_secret;
        /* access request scope for oauth2 server (optional) */
        const char *scope;
+       /* key dict for looking up validation keys */
+       struct dict *key_dict;
+       /* cache for validation keys */
+       struct oauth2_validation_key_cache *key_cache;
+
        enum {
                INTROSPECTION_MODE_GET_AUTH,
                INTROSPECTION_MODE_GET,
@@ -114,4 +121,18 @@ oauth2_refresh_start(const struct oauth2_settings *set,
 /* abort without calling callback, use this to cancel the request */
 void oauth2_request_abort(struct oauth2_request **);
 
+int oauth2_try_parse_jwt(const struct oauth2_settings *set,
+                        const char *token, ARRAY_TYPE(oauth2_field) *fields,
+                        bool *is_jwt_r, const char **error_r);
+
+/* Initialize validation key cache */
+struct oauth2_validation_key_cache *oauth2_validation_key_cache_init(void);
+
+/* Evict given key ID from cache, returns 0 on successful eviction */
+int oauth2_validation_key_cache_evict(struct oauth2_validation_key_cache *cache,
+                                     const char *key_id);
+
+/* Deinitialize validation key cache */
+void oauth2_validation_key_cache_deinit(struct oauth2_validation_key_cache **_cache);
+
 #endif
diff --git a/src/lib-oauth2/test-oauth2-jwt.c b/src/lib-oauth2/test-oauth2-jwt.c
new file mode 100644 (file)
index 0000000..7685cb2
--- /dev/null
@@ -0,0 +1,658 @@
+/* Copyright (c) 2020 Dovecot authors, see the included COPYING file */
+
+#include "lib.h"
+#include "buffer.h"
+#include "str.h"
+#include "ostream.h"
+#include "hmac.h"
+#include "sha2.h"
+#include "base64.h"
+#include "randgen.h"
+#include "array.h"
+#include "json-parser.h"
+#include "iso8601-date.h"
+#include "oauth2.h"
+#include "oauth2-private.h"
+#include "dcrypt.h"
+#include "dict.h"
+#include "dict-private.h"
+#include "test-common.h"
+#include "unlink-directory.h"
+
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#define base64url_encode_str(str, dest) \
+       base64url_encode(BASE64_ENCODE_FLAG_NO_PADDING, (size_t)-1, (str), \
+                        strlen((str)), (dest))
+
+/**
+ * Test keypair used only for this test.
+ */
+static const char *rsa_public_key =
+"-----BEGIN PUBLIC KEY-----\n"
+"MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAnzyis1ZjfNB0bBgKFMSv\n"
+"vkTtwlvBsaJq7S5wA+kzeVOVpVWwkWdVha4s38XM/pa/yr47av7+z3VTmvDRyAHc\n"
+"aT92whREFpLv9cj5lTeJSibyr/Mrm/YtjCZVWgaOYIhwrXwKLqPr/11inWsAkfIy\n"
+"tvHWTxZYEcXLgAXFuUuaS3uF9gEiNQwzGTU1v0FqkqTBr4B8nW3HCN47XUu0t8Y0\n"
+"e+lf4s4OxQawWD79J9/5d3Ry0vbV3Am1FtGJiJvOwRsIfVChDpYStTcHTCMqtvWb\n"
+"V6L11BWkpzGXSW4Hv43qa+GSYOD2QU68Mb59oSk2OB+BtOLpJofmbGEGgvmwyCI9\n"
+"MwIDAQAB\n"
+"-----END PUBLIC KEY-----";
+static const char *rsa_private_key =
+"-----BEGIN PRIVATE KEY-----\n"
+"MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCfPKKzVmN80HRs\n"
+"GAoUxK++RO3CW8GxomrtLnAD6TN5U5WlVbCRZ1WFrizfxcz+lr/Kvjtq/v7PdVOa\n"
+"8NHIAdxpP3bCFEQWku/1yPmVN4lKJvKv8yub9i2MJlVaBo5giHCtfAouo+v/XWKd\n"
+"awCR8jK28dZPFlgRxcuABcW5S5pLe4X2ASI1DDMZNTW/QWqSpMGvgHydbccI3jtd\n"
+"S7S3xjR76V/izg7FBrBYPv0n3/l3dHLS9tXcCbUW0YmIm87BGwh9UKEOlhK1NwdM\n"
+"Iyq29ZtXovXUFaSnMZdJbge/jepr4ZJg4PZBTrwxvn2hKTY4H4G04ukmh+ZsYQaC\n"
+"+bDIIj0zAgMBAAECggEAKIBGrbCSW2O1yOyQW9nvDUkA5EdsS58Q7US7bvM4iWpu\n"
+"DIBwCXur7/VuKnhn/HUhURLzj/JNozynSChqYyG+CvL+ZLy82LUE3ZIBkSdv/vFL\n"
+"Ft+VvvRtf1EcsmoqenkZl7aN7HD7DJeXBoz5tyVQKuH17WW0fsi9StGtCcUl+H6K\n"
+"zV9Gif0Kj0uLQbCg3THRvKuueBTwCTdjoP0PwaNADgSWb3hJPeLMm/yII4tIMGbO\n"
+"w+xd9wJRl+ZN9nkNtQMxszFGdKjedB6goYLQuP0WRZx+YtykaVJdM75bDUvsQar4\n"
+"9Pc21Fp7UVk/CN11DX/hX3TmTJAUtqYADliVKkTbCQKBgQDLU48tBxm3g1CdDM/P\n"
+"ZIEmpA3Y/m7e9eX7M1Uo/zDh4G/S9a4kkX6GQY2dLFdCtOS8M4hR11Io7MceBKDi\n"
+"djorTZ5zJPQ8+b9Rm+1GlaucGNwRW0cQk2ltT2ksPmJnQn2xvM9T8vE+a4A/YGzw\n"
+"mZOfpoVGykWs/tbSzU2aTaOybQKBgQDIfRf6OmirGPh59l+RSuDkZtISF/51mCV/\n"
+"S1M4DltWDwhjC2Y2T+meIsb/Mjtz4aVNz0EHB8yvn0TMGr94Uwjv4uBdpVSwz+xL\n"
+"hHL7J4rpInH+i0gxa0N+rGwsPwI8wJG95wLY+Kni5KCuXQw55uX1cqnnsahpRZFZ\n"
+"EerBXhjqHwKBgBmEjiaHipm2eEqNjhMoOPFBi59dJ0sCL2/cXGa9yEPA6Cfgv49F\n"
+"V0zAM2azZuwvSbm4+fXTgTMzrDW/PPXPArPmlOk8jQ6OBY3XdOrz48q+b/gZrYyO\n"
+"A6A9ZCSyW6U7+gxxds/BYLeFxF2v21xC2f0iZ/2faykv/oQMUh34en/tAoGACqVZ\n"
+"2JexZyR0TUWf3X80YexzyzIq+OOTWicNzDQ29WLm9xtr2gZ0SUlfd72bGpQoyvDu\n"
+"awkm/UxfwtbIxALkvpg1gcN9s8XWrkviLyPyZF7H3tRWiQlBFEDjnZXa8I7pLkRO\n"
+"Cmdp3fp17cxTEeAI5feovfzZDH39MdWZuZrdh9ECgYBTEv8S7nK8wrxIC390kroV\n"
+"52eBwzckQU2mWa0thUtaGQiU1EYPCSDcjkrLXwB72ft0dW57KyWtvrB6rt1ORgOL\n"
+"eI5hFbwdGQhCHTrAR1vG3SyFPMAm+8JB+sGOD/fvjtZKx//MFNweKFNEF0C/o6Z2\n"
+"FXj90PlgF8sCQut36ZfuIQ==\n"
+"-----END PRIVATE KEY-----";
+
+static buffer_t *hs_sign_key = NULL;
+
+static struct dict *keys_dict = NULL;
+
+static bool skip_dcrypt = FALSE;
+
+static struct oauth2_validation_key_cache *key_cache = NULL;
+
+static int parse_jwt_token(struct oauth2_request *req, const char *token,
+                          bool *is_jwt_r, const char **error_r)
+{
+       struct oauth2_settings set;
+       set.scope = "mail";
+       set.key_dict = keys_dict;
+       set.key_cache = key_cache;
+       i_zero(req);
+       req->pool = pool_datastack_create();
+       req->set = &set;
+       t_array_init(&req->fields, 8);
+       return oauth2_try_parse_jwt(&set, token, &req->fields, is_jwt_r, error_r);
+}
+
+static void test_jwt_token(const char *token)
+{
+       /* then see what the parser likes it */
+       struct oauth2_request req;
+       const char *error = NULL;
+       bool is_jwt;
+       test_assert(parse_jwt_token(&req, token, &is_jwt, &error) == 0);
+       test_assert(is_jwt == TRUE);
+       test_assert(error == NULL);
+
+       /* check fields */
+       test_assert(array_is_created(&req.fields));
+       if (array_is_created(&req.fields)) {
+               const struct oauth2_field *field;
+               bool got_sub = FALSE;
+               array_foreach(&req.fields, field) {
+                       if (strcmp(field->name, "sub") == 0) {
+                               test_assert_strcmp(field->value, "testuser");
+                               got_sub = TRUE;
+                       }
+               }
+               test_assert(got_sub == TRUE);
+       }
+
+       if (error != NULL)
+               i_error("%s", error);
+}
+
+static buffer_t *create_jwt_token_kid(const char *algo, const char *kid)
+{
+       /* make a token */
+       buffer_t *tokenbuf = t_buffer_create(64);
+
+       /* header */
+       base64url_encode_str(t_strdup_printf(
+                               "{\"alg\":\"%s\",\"typ\":\"JWT\",\"kid\":\"%s\"}",
+                                algo, kid), tokenbuf);
+       buffer_append(tokenbuf, ".", 1);
+
+       /* body */
+       base64url_encode_str(t_strdup_printf("{\"sub\":\"testuser\","\
+                               "\"iat\":%"PRIdTIME_T","
+                               "\"exp\":%"PRIdTIME_T"}",
+                                       time(NULL),
+                                       time(NULL)+600), tokenbuf);
+       return tokenbuf;
+}
+
+static buffer_t *create_jwt_token(const char *algo)
+{
+       /* make a token */
+       buffer_t *tokenbuf = t_buffer_create(64);
+
+       /* header */
+       base64url_encode_str(t_strdup_printf(
+                               "{\"alg\":\"%s\",\"typ\":\"JWT\"}", algo), tokenbuf);
+       buffer_append(tokenbuf, ".", 1);
+
+       /* body */
+       base64url_encode_str(t_strdup_printf("{\"sub\":\"testuser\","\
+                               "\"iat\":%"PRIdTIME_T","
+                               "\"exp\":%"PRIdTIME_T"}",
+                                       time(NULL),
+                                       time(NULL)+600), tokenbuf);
+       return tokenbuf;
+}
+
+static void append_key_value(string_t *dest, const char *key, const char *value, bool str)
+{
+       str_append_c(dest, '"');
+       json_append_escaped(dest, key);
+       str_append(dest, "\":");
+       if (str)
+               str_append_c(dest, '"');
+       json_append_escaped(dest, value);
+       if (str)
+               str_append_c(dest, '"');
+
+}
+
+static buffer_t *create_jwt_token_fields(const char *algo, time_t exp, time_t iat,
+                                        time_t nbf, ARRAY_TYPE(oauth2_field) *fields)
+{
+       const struct oauth2_field *field;
+       buffer_t *tokenbuf = t_buffer_create(64);
+       base64url_encode_str(t_strdup_printf(
+                               "{\"alg\":\"%s\",\"typ\":\"JWT\"}", algo), tokenbuf);
+       buffer_append(tokenbuf, ".", 1);
+       string_t *bodybuf = t_str_new(64);
+       str_append_c(bodybuf, '{');
+       if (exp > 0) {
+               append_key_value(bodybuf, "exp", dec2str(exp), FALSE);
+       }
+       if (iat > 0) {
+               if (exp > 0)
+                       str_append_c(bodybuf, ',');
+               append_key_value(bodybuf, "iat", dec2str(iat), FALSE);
+       }
+       if (nbf > 0) {
+               if (exp > 0 || iat > 0)
+                       str_append_c(bodybuf, ',');
+               append_key_value(bodybuf, "nbf", dec2str(nbf), FALSE);
+       }
+       array_foreach(fields, field) {
+               if (str_data(bodybuf)[bodybuf->used-1] != '{')
+                       str_append_c(bodybuf, ',');
+               append_key_value(bodybuf, field->name, field->value, TRUE);
+       }
+       str_append_c(bodybuf, '}');
+       base64url_encode_str(str_c(bodybuf), tokenbuf);
+
+       return tokenbuf;
+}
+
+#define save_key(key) save_key_to("default", (key))
+static void save_key_to(const char *name, const char *keydata)
+{
+       const char *error;
+       struct dict_transaction_context *ctx = dict_transaction_begin(keys_dict);
+       dict_set(ctx, t_strconcat(DICT_PATH_SHARED, name, NULL), keydata);
+       if (dict_transaction_commit(&ctx, &error) < 0)
+               i_error("dict_set(%s) failed: %s", name, error);
+}
+
+static void sign_jwt_token_hs256(buffer_t *tokenbuf, buffer_t *key)
+{
+       i_assert(key != NULL);
+       buffer_t *sig = t_hmac_buffer(&hash_method_sha256, key->data, key->used,
+                                     tokenbuf);
+       buffer_append(tokenbuf, ".", 1);
+       base64url_encode(BASE64_ENCODE_FLAG_NO_PADDING, (size_t)-1,
+                        sig->data, sig->used, tokenbuf);
+}
+
+static void test_jwt_hs_token(void)
+{
+       test_begin("JWT HMAC token");
+
+       /* make a token */
+       buffer_t *tokenbuf = create_jwt_token("HS256");
+       /* sign it */
+       sign_jwt_token_hs256(tokenbuf, hs_sign_key);
+       test_jwt_token(str_c(tokenbuf));
+
+       test_end();
+}
+
+static void test_jwt_broken_token(void)
+{
+       struct test_cases {
+               const char *token;
+               bool is_jwt;
+       } test_cases[] = {
+               { /* empty token */
+                       .token = "",
+                       .is_jwt = FALSE
+               },
+               { /* not base64 */
+                       .token = "{\"alg\":\"HS256\":\"typ\":\"JWT\"}",
+                       .is_jwt = FALSE
+               },
+               { /* not jwt */
+                       .token = "aGVsbG8sIHdvcmxkCg",
+                       .is_jwt = FALSE
+               },
+               { /* no alg field */
+                       .token = "eyJ0eXAiOiAiSldUIn0",
+                       .is_jwt = FALSE
+               },
+               { /* no typ field */
+                       .token = "eyJhbGciOiAiSFMyNTYifQ",
+                       .is_jwt = FALSE
+               },
+               { /* typ field is wrong */
+                       .token = "eyJ0eXAiOiAiand0IiwgImFsZyI6ICJIUzI1NiJ9."
+                                "eyJhbGdvIjogIldURiIsICJ0eXAiOiAiSldUIn0."
+                                "q2wwwWWJVJxqw-J3uQ0DdlIyWfoZ7Z0QrdzvMW_B-jo",
+                       .is_jwt = FALSE
+               },
+               { /* unknown algorithm */
+                       .token = "eyJ0eXAiOiAiSldUIiwgImFsZyI6ICJXVEYifQ."
+                                "eyJhbGdvIjogIldURiIsICJ0eXAiOiAiSldUIn0."
+                                "q2wwwWWJVJxqw-J3uQ0DdlIyWfoZ7Z0QrdzvMW_B-jo",
+                       .is_jwt = TRUE
+               },
+               { /* truncated base64 */
+                       .token  = "yJhbGciOiJIUzI1NiIsInR5",
+                       .is_jwt = FALSE
+               },
+               { /* missing body and signature */
+                       .token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
+                       .is_jwt = FALSE
+               },
+               { /* empty body and signature */
+                       .token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..",
+                       .is_jwt = TRUE
+               },
+               { /* empty signature */
+                       .token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
+                                "eyJleHAiOjE1ODEzMzA3OTN9.",
+                       .is_jwt = TRUE
+               },
+               { /* bad signature */
+                       .token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
+                                "eyJleHAiOjE1ODEzMzA3OTN9."
+                                "q2wwwWWJVJxqw-J3uQ0DdlIyWfoZ7Z0QrdzvMW_B-jo",
+                       .is_jwt = TRUE
+               },
+       };
+
+       test_begin("JWT broken tokens");
+
+       for (size_t i = 0; i < N_ELEMENTS(test_cases); i++) T_BEGIN {
+               struct test_cases *test_case = &test_cases[i];
+               struct oauth2_request req;
+               const char *error = NULL;
+               bool is_jwt;
+               test_assert_idx(parse_jwt_token(&req, test_case->token, &is_jwt, &error) != 0, i);
+               test_assert_idx(test_case->is_jwt == is_jwt, i);
+               test_assert_idx(error != NULL, i);
+       } T_END;
+
+       test_end();
+}
+
+static void test_jwt_bad_valid_token(void)
+{
+       test_begin("JWT bad token tests");
+       time_t now = time(NULL);
+
+       struct test_cases {
+               time_t exp;
+               time_t iat;
+               time_t nbf;
+               const char *key_values[20];
+               const char *error;
+       } test_cases[] =
+       {
+               { /* "empty" token */
+                       .exp = 0,
+                       .iat = 0,
+                       .nbf = 0,
+                       .key_values = { NULL },
+                       .error = "Missing 'sub' field",
+               },
+               { /* missing sub field */
+                       .exp = now+500,
+                       .iat = 0,
+                       .nbf = 0,
+                       .key_values = { NULL },
+                       .error = "Missing 'sub' field",
+               },
+               { /* non-ISO date as iat */
+                       .exp = now+500,
+                       .iat = 0,
+                       .nbf = 0,
+                       .key_values = { "sub", "testuser", "iat", "1.1.2019 16:00", NULL },
+                       .error = "Malformed 'iat' field"
+               },
+               { /* expired token */
+                       .exp = now-500,
+                       .iat = 0,
+                       .nbf = 0,
+                       .key_values = { "sub", "testuser", NULL },
+                       .error = "Token has expired",
+               },
+               { /* future token */
+                       .exp = now+1000,
+                       .iat = now+500,
+                       .nbf = 0,
+                       .key_values = { "sub", "testuser", NULL },
+                       .error = "Token is issued in future",
+               },
+               { /* token not valid yet */
+                       .exp = now+500,
+                       .iat = now,
+                       .nbf = now+250,
+                       .key_values = { "sub", "testuser", NULL },
+                       .error = "Token is not valid yet",
+               },
+       };
+
+       for (size_t i = 0; i < N_ELEMENTS(test_cases); i++) T_BEGIN {
+               const struct test_cases *test_case = &test_cases[i];
+               const char *key = NULL;
+               ARRAY_TYPE(oauth2_field) fields;
+               t_array_init(&fields, 8);
+               for (const char *const *value = test_case->key_values; *value != NULL; value++) {
+                       if (key == NULL) {
+                               key = *value;
+                       } else {
+                               struct oauth2_field *field =
+                                       array_append_space(&fields);
+                               field->name = key;
+                               field->value = *value;
+                               key = NULL;
+                       }
+               }
+
+               buffer_t *tokenbuf =
+                       create_jwt_token_fields("HS256", test_case->exp, test_case->iat,
+                                               test_case->nbf, &fields);
+               sign_jwt_token_hs256(tokenbuf, hs_sign_key);
+               struct oauth2_request req;
+               const char *error = NULL;
+               bool is_jwt;
+               test_assert_idx(parse_jwt_token(&req, str_c(tokenbuf), &is_jwt, &error) != 0, i);
+               test_assert_idx(is_jwt == TRUE, i);
+               if (test_case->error != NULL) {
+                       test_assert_strcmp(test_case->error, error);
+               }
+               test_assert(error != NULL);
+       } T_END;
+
+       test_end();
+}
+
+static void test_jwt_dates(void)
+{
+       test_begin("JWT Token dates");
+
+       /* simple check to make sure ISO8601 dates work too */
+       ARRAY_TYPE(oauth2_field) fields;
+       t_array_init(&fields, 8);
+       struct oauth2_field *field;
+       struct tm tm_b;
+       struct tm *tm;
+       time_t now = time(NULL);
+       time_t exp = now+500;
+       time_t nbf = now-250;
+       time_t iat = now-500;
+
+       field = array_append_space(&fields);
+       field->name = "sub";
+       field->value = "testuser";
+       field = array_append_space(&fields);
+       field->name = "exp";
+       tm = gmtime_r(&exp, &tm_b);
+       field->value = iso8601_date_create_tm(tm, INT_MAX);
+       field = array_append_space(&fields);
+       field->name = "nbf";
+       tm = gmtime_r(&nbf, &tm_b);
+       field->value = iso8601_date_create_tm(tm, INT_MAX);
+       field = array_append_space(&fields);
+       field->name = "iat";
+       tm = gmtime_r(&iat, &tm_b);
+       field->value = iso8601_date_create_tm(tm, INT_MAX);
+       buffer_t *tokenbuf = create_jwt_token_fields("HS256", 0, 0, 0, &fields);
+       sign_jwt_token_hs256(tokenbuf, hs_sign_key);
+       test_jwt_token(str_c(tokenbuf));
+
+       test_end();
+}
+
+static void test_jwt_key_files(void)
+{
+       test_begin("JWT key id");
+       /* write HMAC secrets */
+       struct oauth2_request req;
+       bool is_jwt;
+       const char *error = NULL;
+
+       buffer_t *secret = t_buffer_create(32);
+       void *ptr = buffer_append_space_unsafe(secret, 32);
+       random_fill(ptr, 32);
+       buffer_t *b64_key = t_base64_encode(0, (size_t)-1, secret->data, secret->used);
+       save_key_to("first", str_c(b64_key));
+       buffer_t *secret2 = t_buffer_create(32);
+       ptr = buffer_append_space_unsafe(secret2, 32);
+       random_fill(ptr, 32);
+       b64_key = t_base64_encode(0, (size_t)-1, secret2->data, secret2->used);
+       save_key_to("second", str_c(b64_key));
+
+       /* create and sign token */
+       buffer_t *token_1 = create_jwt_token_kid("HS256", "first");
+       buffer_t *token_2 = create_jwt_token_kid("HS256", "second");
+       buffer_t *token_3 = create_jwt_token_kid("HS256", "missing");
+       buffer_t *token_4 = create_jwt_token_kid("HS256", "");
+
+       sign_jwt_token_hs256(token_1, secret);
+       sign_jwt_token_hs256(token_2, secret2);
+       sign_jwt_token_hs256(token_3, secret);
+       sign_jwt_token_hs256(token_4, secret);
+
+       test_jwt_token(str_c(token_1));
+       test_jwt_token(str_c(token_2));
+
+       test_assert(parse_jwt_token(&req, str_c(token_3), &is_jwt, &error) != 0);
+       test_assert(is_jwt == TRUE);
+       test_assert_strcmp(error, "Key 'missing' not found");
+       test_assert(parse_jwt_token(&req, str_c(token_4), &is_jwt, &error) != 0);
+       test_assert(is_jwt == TRUE);
+       test_assert_strcmp(error, "'kid' field is empty");
+
+       test_end();
+}
+
+static void test_jwt_rs_token(void)
+{
+       const char *error;
+       if (skip_dcrypt)
+               return;
+
+       test_begin("JWT RSA token");
+       /* write public key to file */
+       oauth2_validation_key_cache_evict(key_cache, "default");
+       save_key(rsa_public_key);
+
+       buffer_t *tokenbuf = create_jwt_token("RS256");
+       /* sign token */
+       buffer_t *sig = t_buffer_create(64);
+       struct dcrypt_private_key *key;
+       if (!dcrypt_key_load_private(&key, rsa_private_key, NULL, NULL, &error) ||
+           !dcrypt_sign(key, "sha256", DCRYPT_SIGNATURE_FORMAT_DSS,
+                        tokenbuf->data, tokenbuf->used, sig,
+                        DCRYPT_PADDING_RSA_PKCS1, &error)) {
+               i_error("dcrypt signing failed: %s", error);
+               exit(1);
+       }
+       dcrypt_key_unref_private(&key);
+       /* convert to base64 */
+       buffer_append(tokenbuf, ".", 1);
+       base64url_encode(BASE64_ENCODE_FLAG_NO_PADDING, (size_t)-1,
+                        sig->data, sig->used, tokenbuf);
+
+       test_jwt_token(str_c(tokenbuf));
+
+       test_end();
+}
+
+static void test_jwt_ps_token(void)
+{
+       const char *error;
+       if (skip_dcrypt)
+               return;
+
+       test_begin("JWT RSAPSS token");
+       /* write public key to file */
+       oauth2_validation_key_cache_evict(key_cache, "default");
+       save_key(rsa_public_key);
+
+       buffer_t *tokenbuf = create_jwt_token("PS256");
+       /* sign token */
+       buffer_t *sig = t_buffer_create(64);
+       struct dcrypt_private_key *key;
+       if (!dcrypt_key_load_private(&key, rsa_private_key, NULL, NULL, &error) ||
+           !dcrypt_sign(key, "sha256", DCRYPT_SIGNATURE_FORMAT_DSS,
+                        tokenbuf->data, tokenbuf->used, sig,
+                        DCRYPT_PADDING_RSA_PKCS1_PSS, &error)) {
+               i_error("dcrypt signing failed: %s", error);
+               exit(1);
+       }
+       dcrypt_key_unref_private(&key);
+       /* convert to base64 */
+       buffer_append(tokenbuf, ".", 1);
+       base64url_encode(BASE64_ENCODE_FLAG_NO_PADDING, (size_t)-1,
+                        sig->data, sig->used, tokenbuf);
+
+       test_jwt_token(str_c(tokenbuf));
+
+       test_end();
+}
+
+static void test_jwt_ec_token(void)
+{
+       const char *error;
+       if (skip_dcrypt)
+               return;
+
+       test_begin("JWT ECDSA token");
+       struct dcrypt_keypair pair;
+       i_zero(&pair);
+       if (!dcrypt_keypair_generate(&pair, DCRYPT_KEY_EC, 0,
+                                    "prime256v1", &error)) {
+               i_error("dcrypt keypair generate failed: %s", error);
+               exit(1);
+       }
+       /* export public key */
+       buffer_t *keybuf = t_buffer_create(256);
+       if (!dcrypt_key_store_public(pair.pub, DCRYPT_FORMAT_PEM, keybuf, &error)) {
+               i_error("dcrypt key store failed: %s", error);
+               exit(1);
+       }
+       oauth2_validation_key_cache_evict(key_cache, "default");
+       save_key(str_c(keybuf));
+
+       buffer_t *tokenbuf = create_jwt_token("ES256");
+       /* sign token */
+       buffer_t *sig = t_buffer_create(64);
+       if (!dcrypt_sign(pair.priv, "sha256", DCRYPT_SIGNATURE_FORMAT_X962,
+                        tokenbuf->data, tokenbuf->used, sig,
+                        DCRYPT_PADDING_DEFAULT, &error)) {
+               i_error("dcrypt signing failed: %s", error);
+               exit(1);
+       }
+       dcrypt_keypair_unref(&pair);
+       /* convert to base64 */
+       buffer_append(tokenbuf, ".", 1);
+       base64url_encode(BASE64_ENCODE_FLAG_NO_PADDING, (size_t)-1,
+                        sig->data, sig->used, tokenbuf);
+       test_jwt_token(str_c(tokenbuf));
+
+       test_end();
+}
+
+static void test_do_init(void)
+{
+       const char *error;
+       struct dcrypt_settings dcrypt_set = {
+               .module_dir = "../lib-dcrypt/.libs",
+       };
+       struct dict_settings dict_set = {
+               .username = "testuser",
+               .value_type = DICT_DATA_TYPE_STRING,
+               .base_dir = ".",
+       };
+       i_unlink_if_exists(".keys");
+       dict_driver_register(&dict_driver_file);
+       if (dict_init("file:.keys", &dict_set, &keys_dict, &error) < 0)
+               i_fatal("dict_init(file:.keys): %s", error);
+       if (!dcrypt_initialize(NULL, &dcrypt_set, &error)) {
+               i_error("No functional dcrypt backend found - "
+                       "skipping some tests: %s", error);
+               skip_dcrypt = TRUE;
+       }
+       key_cache = oauth2_validation_key_cache_init();
+       /* write HMAC secret */
+       hs_sign_key =buffer_create_dynamic(default_pool, 32);
+       void *ptr = buffer_append_space_unsafe(hs_sign_key, 32);
+       random_fill(ptr, 32);
+       buffer_t *b64_key = t_base64_encode(0, (size_t)-1,
+                                           hs_sign_key->data, hs_sign_key->used);
+       save_key(str_c(b64_key));
+}
+
+static void test_do_deinit(void)
+{
+       dict_deinit(&keys_dict);
+       dict_driver_unregister(&dict_driver_file);
+       oauth2_validation_key_cache_deinit(&key_cache);
+       i_unlink(".keys");
+       buffer_free(&hs_sign_key);
+       dcrypt_deinitialize();
+}
+
+int main(void)
+{
+       static void (*test_functions[])(void) = {
+               test_do_init,
+               test_jwt_hs_token,
+               test_jwt_bad_valid_token,
+               test_jwt_broken_token,
+               test_jwt_dates,
+               test_jwt_key_files,
+               test_jwt_rs_token,
+               test_jwt_ps_token,
+               test_jwt_ec_token,
+               test_do_deinit,
+               NULL
+       };
+       int ret;
+       ret = test_run(test_functions);
+       return ret;
+}
+