]> git.ipfire.org Git - thirdparty/dovecot/core.git/commitdiff
auth: mech-scram - Split off core implementation for server-side SCRAM-SHA-* authenti...
authorStephan Bosch <stephan.bosch@open-xchange.com>
Mon, 26 Sep 2022 19:42:46 +0000 (21:42 +0200)
committeraki.tuomi <aki.tuomi@open-xchange.com>
Fri, 27 Jan 2023 09:34:54 +0000 (09:34 +0000)
src/auth/Makefile.am
src/auth/auth-scram-server.c
src/auth/auth-scram-server.h [new file with mode: 0644]
src/auth/auth-scram.h
src/auth/mech-scram.c

index 3b95dfdeb0614a784b2d45b4f6a17e57ef13933e..36a1f3e1253d429e94eebd9c0a14e4821e765867 100644 (file)
@@ -64,6 +64,7 @@ libpassword_la_SOURCES = \
        crypt-blowfish.c \
        mycrypt.c \
        auth-scram.c \
+       auth-scram-server.c \
        password-scheme.c \
        password-scheme-crypt.c \
        password-scheme-md5crypt.c \
@@ -183,6 +184,7 @@ headers = \
        passdb-cache.h \
        passdb-template.h \
        auth-scram.h \
+       auth-scram-server.h \
        password-scheme.h \
        userdb.h \
        userdb-blocking.h \
index 8b6153e1c121f1458036f1c1ad468f8762b263f4..3ed3953aaef6cde15001017ab53b736a9be336cd 100644 (file)
@@ -1,3 +1,77 @@
+/*
+ * SCRAM-SHA-1 SASL authentication, see RFC-5802
+ *
+ * Copyright (c) 2011-2016 Florian Zeitz <florob@babelmonkeys.de>
+ * Copyright (c) 2022-2023 Dovecot Oy
+ *
+ * This software is released under the MIT license.
+ */
+
+#include "lib.h"
+#include "base64.h"
+#include "buffer.h"
+#include "hmac.h"
+#include "randgen.h"
+#include "safe-memset.h"
+#include "str.h"
+#include "strfuncs.h"
+#include "strnum.h"
+
+#include "auth-scram.h"
+#include "auth-scram-server.h"
+
+/* s-nonce length */
+#define SCRAM_SERVER_NONCE_LEN 64
+
+static bool
+auth_scram_server_set_username(struct auth_scram_server *server,
+                              const char *username, const char **error_r)
+{
+       return server->backend->set_username(server, username, error_r);
+}
+static bool
+auth_scram_server_set_login_username(struct auth_scram_server *server,
+                                    const char *username, const char **error_r)
+{
+       return server->backend->set_login_username(server, username,
+                                                  error_r);
+}
+
+static int
+auth_scram_server_credentials_lookup(struct auth_scram_server *server)
+{
+       const struct hash_method *hmethod = server->hash_method;
+       struct auth_scram_key_data *kdata = &server->key_data;
+       pool_t pool = server->pool;
+
+       i_zero(kdata);
+       kdata->pool = pool;
+       kdata->hmethod = hmethod;
+       kdata->stored_key = p_malloc(pool, hmethod->digest_size);
+       kdata->server_key = p_malloc(pool, hmethod->digest_size);
+
+       return server->backend->credentials_lookup(server, kdata);
+}
+
+void auth_scram_server_init(struct auth_scram_server *server_r, pool_t pool,
+                           const struct hash_method *hmethod,
+                           const struct auth_scram_server_backend *backend)
+{
+       pool_ref(pool);
+
+       i_zero(server_r);
+       server_r->pool = pool;
+       server_r->hash_method = hmethod;
+
+       server_r->backend = backend;
+}
+
+void auth_scram_server_deinit(struct auth_scram_server *server)
+{
+       i_assert(server->hash_method != NULL);
+       pool_unref(&server->pool);
+}
+
 static const char *auth_scram_unescape_username(const char *in)
 {
        string_t *out;
@@ -30,8 +104,10 @@ static const char *auth_scram_unescape_username(const char *in)
 }
 
 static int
-auth_scram_parse_client_first(struct scram_auth_request *server,
+auth_scram_parse_client_first(struct auth_scram_server *server,
                              const unsigned char *data, size_t size,
+                             const char **username_r,
+                             const char **login_username_r,
                              const char **error_r)
 {
        const char *login_username = NULL;
@@ -137,18 +213,10 @@ auth_scram_parse_client_first(struct scram_auth_request *server,
                        *error_r = "Username escaping is invalid";
                        return -1;
                }
-               if (!auth_request_set_username(&server->auth_request,
-                                              username, error_r))
-                       return -1;
        } else {
                *error_r = "Invalid username field";
                return -1;
        }
-       if (login_username != NULL) {
-               if (!auth_request_set_login_username(&server->auth_request,
-                                                    login_username, error_r))
-                       return -1;
-       }
 
        /* nonce           = "r=" c-nonce [s-nonce] */
        if (nonce[0] == 'r' && nonce[1] == '=')
@@ -158,6 +226,9 @@ auth_scram_parse_client_first(struct scram_auth_request *server,
                return -1;
        }
 
+       *username_r = username;
+       *login_username_r = login_username;
+
        server->gs2_header = p_strdup(server->pool, gs2_header);
        server->client_first_message_bare =
                p_strdup(server->pool, cfm_bare);
@@ -165,7 +236,7 @@ auth_scram_parse_client_first(struct scram_auth_request *server,
 }
 
 static string_t *
-auth_scram_get_server_first(struct scram_auth_request *server)
+auth_scram_get_server_first(struct auth_scram_server *server)
 {
        const struct hash_method *hmethod = server->hash_method;
        struct auth_scram_key_data *kdata = &server->key_data;
@@ -187,6 +258,7 @@ auth_scram_get_server_first(struct scram_auth_request *server)
                             ;; A positive number.
         */
 
+       i_assert(kdata->pool == server->pool);
        i_assert(kdata->hmethod == hmethod);
        i_assert(kdata->salt != NULL);
        i_assert(kdata->iter_count != 0);
@@ -206,11 +278,14 @@ auth_scram_get_server_first(struct scram_auth_request *server)
                        strlen(kdata->salt));
        str_printfa(str, "r=%s%s,s=%s,i=%d", server->cnonce, server->snonce,
                    kdata->salt, kdata->iter_count);
+
+       server->server_first_message = p_strdup(server->pool, str_c(str));
+
        return str;
 }
 
 static bool
-auth_scram_server_verify_credentials(struct scram_auth_request *server)
+auth_scram_server_verify_credentials(struct auth_scram_server *server)
 {
        const struct hash_method *hmethod = server->hash_method;
        struct auth_scram_key_data *kdata = &server->key_data;
@@ -221,6 +296,7 @@ auth_scram_server_verify_credentials(struct scram_auth_request *server)
        unsigned char stored_key[hmethod->digest_size];
        size_t i;
 
+       i_assert(kdata->pool == server->pool);
        i_assert(kdata->hmethod == hmethod);
 
        /* RFC 5802, Section 3:
@@ -255,7 +331,7 @@ auth_scram_server_verify_credentials(struct scram_auth_request *server)
 }
 
 static int
-auth_scram_parse_client_final(struct scram_auth_request *server,
+auth_scram_parse_client_final(struct auth_scram_server *server,
                              const unsigned char *data, size_t size,
                              const char **error_r)
 {
@@ -338,7 +414,7 @@ auth_scram_parse_client_final(struct scram_auth_request *server,
 }
 
 static string_t *
-auth_scram_get_server_final(struct scram_auth_request *server)
+auth_scram_get_server_final(struct auth_scram_server *server)
 {
        const struct hash_method *hmethod = server->hash_method;
        struct auth_scram_key_data *kdata = &server->key_data;
@@ -377,3 +453,206 @@ auth_scram_get_server_final(struct scram_auth_request *server)
 
        return str;
 }
+
+static int
+auth_scram_parse_client_finish(struct auth_scram_server *server ATTR_UNUSED,
+                              const unsigned char *data ATTR_UNUSED,
+                              size_t size, const char **error_r)
+{
+       if (size != 0) {
+               *error_r = "Spurious extra client message";
+               return -1;
+       }
+       return 0;
+}
+
+bool auth_scram_server_acces_granted(struct auth_scram_server *server)
+{
+       return (server->state == AUTH_SCRAM_SERVER_STATE_SERVER_FINAL);
+}
+
+static int
+auth_scram_server_input_client_first(struct auth_scram_server *server,
+                                    const unsigned char *input,
+                                    size_t input_len,
+                                    enum auth_scram_server_error *error_code_r,
+                                    const char **error_r)
+{
+       const char *username, *login_username;
+       int ret;
+
+       username = login_username = NULL;
+       
+       /* Parse client-first message */
+       ret = auth_scram_parse_client_first(server, input, input_len,
+                                           &username, &login_username,
+                                           error_r);
+       if (ret < 0) {
+               *error_code_r = AUTH_SCRAM_SERVER_ERROR_PROTOCOL_VIOLATION;
+               return -1;
+       }
+
+       /* Pass usernames to backend */
+       i_assert(username != NULL);
+       if (!auth_scram_server_set_username(server, username, error_r)) {
+               *error_code_r = AUTH_SCRAM_SERVER_ERROR_BAD_USERNAME;
+               return -1;
+       }
+       if (login_username != NULL &&
+           !auth_scram_server_set_login_username(server, login_username,
+                                                 error_r)) {
+               *error_code_r = AUTH_SCRAM_SERVER_ERROR_BAD_LOGIN_USERNAME;
+               return -1;
+       }
+       
+       return 0;
+}
+
+static int
+auth_scram_server_input_client_final(struct auth_scram_server *server,
+                                    const unsigned char *input,
+                                    size_t input_len,
+                                    enum auth_scram_server_error *error_code_r,
+                                    const char **error_r)
+{
+       int ret;
+       
+       /* Parse client-final message */
+       ret = auth_scram_parse_client_final(server, input, input_len, error_r);
+       if (ret < 0) {
+               *error_code_r = AUTH_SCRAM_SERVER_ERROR_PROTOCOL_VIOLATION;
+               return -1;
+       }
+
+       /* Verify client credentials */
+       if (!auth_scram_server_verify_credentials(server)) {
+               *error_code_r = AUTH_SCRAM_SERVER_ERROR_VERIFICATION_FAILED;
+               *error_r = "Password mismatch";
+               return -1;
+       }
+
+       return 0;
+}
+
+int auth_scram_server_input(struct auth_scram_server *server,
+                           const unsigned char *input, size_t input_len,
+                           enum auth_scram_server_error *error_code_r,
+                           const char **error_r)
+{
+       struct auth_scram_key_data *kdata = &server->key_data;
+       int ret = 0;
+
+       *error_code_r = AUTH_SCRAM_SERVER_ERROR_NONE;
+       *error_r = NULL;
+
+       switch (server->state) {
+       case AUTH_SCRAM_SERVER_STATE_INIT:
+               server->state = AUTH_SCRAM_SERVER_STATE_CLIENT_FIRST;
+               /* Fall through */
+       case AUTH_SCRAM_SERVER_STATE_CLIENT_FIRST:
+               /* Handle client-first message */
+               ret = auth_scram_server_input_client_first(
+                       server, input, input_len, error_code_r, error_r);
+               if (ret < 0) {
+                       server->state = AUTH_SCRAM_SERVER_STATE_ERROR;
+                       ret = -1;
+                       break;
+               }
+
+               /* Initiate credentials lookup */
+               server->state = AUTH_SCRAM_SERVER_STATE_CREDENTIALS_LOOKUP;
+               if (auth_scram_server_credentials_lookup(server) < 0) {
+                       *error_code_r = AUTH_SCRAM_SERVER_ERROR_LOOKUP_FAILED;
+                       *error_r = "Credentials lookup failed";
+                       server->state = AUTH_SCRAM_SERVER_STATE_ERROR;
+                       ret = -1;
+                       break;
+               }
+               if (server->state ==
+                   AUTH_SCRAM_SERVER_STATE_CREDENTIALS_LOOKUP) {
+                       server->state = AUTH_SCRAM_SERVER_STATE_SERVER_FIRST;
+                       ret = (kdata->salt != NULL ? 1 : 0);
+                       break;
+               }
+               i_assert(server->state >= AUTH_SCRAM_SERVER_STATE_SERVER_FIRST);
+               ret = 0;
+               break;
+       case AUTH_SCRAM_SERVER_STATE_CREDENTIALS_LOOKUP:
+       case AUTH_SCRAM_SERVER_STATE_SERVER_FIRST:
+               i_unreached();
+       case AUTH_SCRAM_SERVER_STATE_CLIENT_FINAL:
+               /* Handle client-final message */
+               ret = auth_scram_server_input_client_final(
+                       server, input, input_len, error_code_r, error_r);
+               if (ret < 0) {
+                       server->state = AUTH_SCRAM_SERVER_STATE_ERROR;
+                       break;
+               }
+               server->state = AUTH_SCRAM_SERVER_STATE_SERVER_FINAL;
+               ret = 1;
+               break;
+       case AUTH_SCRAM_SERVER_STATE_SERVER_FINAL:
+               i_unreached();
+       case AUTH_SCRAM_SERVER_STATE_CLIENT_FINISH:
+               server->state = AUTH_SCRAM_SERVER_STATE_END;
+               ret = auth_scram_parse_client_finish(server, input, input_len,
+                                                    error_r);
+               if (ret < 0) {
+                       *error_code_r =
+                               AUTH_SCRAM_SERVER_ERROR_PROTOCOL_VIOLATION;
+                       server->state = AUTH_SCRAM_SERVER_STATE_ERROR;
+               }
+               break;
+       case AUTH_SCRAM_SERVER_STATE_END:
+       case AUTH_SCRAM_SERVER_STATE_ERROR:
+               i_unreached();
+       }
+
+       return ret;
+}
+
+bool auth_scram_server_output(struct auth_scram_server *server,
+                             const unsigned char **output_r,
+                             size_t *output_len_r)
+{
+       struct auth_scram_key_data *kdata = &server->key_data;
+       string_t *output;
+       bool result = FALSE;
+
+       switch (server->state) {
+       case AUTH_SCRAM_SERVER_STATE_INIT:
+               *output_r = uchar_empty_ptr;
+               *output_len_r = 0;
+               server->state = AUTH_SCRAM_SERVER_STATE_CLIENT_FIRST;
+               break;
+       case AUTH_SCRAM_SERVER_STATE_CLIENT_FIRST:
+               i_unreached();
+       case AUTH_SCRAM_SERVER_STATE_CREDENTIALS_LOOKUP:
+               i_assert(kdata->salt != NULL);
+               server->state = AUTH_SCRAM_SERVER_STATE_SERVER_FIRST;
+               /* Fall through */
+       case AUTH_SCRAM_SERVER_STATE_SERVER_FIRST:
+               /* Compose server-first message */
+               output = auth_scram_get_server_first(server);
+               *output_r = str_data(output);
+               *output_len_r = str_len(output);
+               server->state = AUTH_SCRAM_SERVER_STATE_CLIENT_FINAL;
+               break;
+       case AUTH_SCRAM_SERVER_STATE_CLIENT_FINAL:
+               i_unreached();
+       case AUTH_SCRAM_SERVER_STATE_SERVER_FINAL:
+               /* Compose server-final message */
+               output = auth_scram_get_server_final(server);
+               *output_r = str_data(output);
+               *output_len_r = str_len(output);
+               server->state = AUTH_SCRAM_SERVER_STATE_CLIENT_FINISH;
+               result = TRUE;
+               break;
+       case AUTH_SCRAM_SERVER_STATE_CLIENT_FINISH:
+       case AUTH_SCRAM_SERVER_STATE_END:
+       case AUTH_SCRAM_SERVER_STATE_ERROR:
+               i_unreached();
+       }
+
+       return result;
+}
diff --git a/src/auth/auth-scram-server.h b/src/auth/auth-scram-server.h
new file mode 100644 (file)
index 0000000..94c416b
--- /dev/null
@@ -0,0 +1,101 @@
+#ifndef AUTH_SCRAM_SERVER_H
+#define AUTH_SCRAM_SERVER_H
+
+#include "auth-scram.h"
+
+struct auth_scram_server;
+
+enum auth_scram_server_error {
+       /* Success */
+       AUTH_SCRAM_SERVER_ERROR_NONE,
+       /* Protocol violation */
+       AUTH_SCRAM_SERVER_ERROR_PROTOCOL_VIOLATION,
+       /* Backend rejected the username provided by the client as invalid */
+       AUTH_SCRAM_SERVER_ERROR_BAD_USERNAME,
+       /* Something went wrong passing the login username to the backend */
+       AUTH_SCRAM_SERVER_ERROR_BAD_LOGIN_USERNAME,
+       /* Credentials lookup failed (nonexistent user or internal error). */
+       AUTH_SCRAM_SERVER_ERROR_LOOKUP_FAILED,
+       /* Credentials provided by client failed to verify against the
+          credentials looked up earlier. */
+       AUTH_SCRAM_SERVER_ERROR_VERIFICATION_FAILED,
+};
+
+enum auth_scram_server_state {
+       AUTH_SCRAM_SERVER_STATE_INIT = 0,
+       AUTH_SCRAM_SERVER_STATE_CLIENT_FIRST,
+       AUTH_SCRAM_SERVER_STATE_CREDENTIALS_LOOKUP,
+       AUTH_SCRAM_SERVER_STATE_SERVER_FIRST,
+       AUTH_SCRAM_SERVER_STATE_CLIENT_FINAL,
+       AUTH_SCRAM_SERVER_STATE_SERVER_FINAL,
+       AUTH_SCRAM_SERVER_STATE_CLIENT_FINISH,
+       AUTH_SCRAM_SERVER_STATE_END,
+       AUTH_SCRAM_SERVER_STATE_ERROR,
+};
+
+struct auth_scram_server_backend {
+       /* Pass the authentication and authorization usernames to the
+          backend. */
+       bool (*set_username)(struct auth_scram_server *server,
+                            const char *username, const char **error_r);
+       bool (*set_login_username)(struct auth_scram_server *server,
+                                  const char *username, const char **error_r);
+
+       /* Instruct the backend to perform credentials lookup. The acquired
+          credentials are to be assigned to the provided key_data struct
+          eventually. If not immediately, the backend is supposed to call
+          auth_scram_server_output() later once the key_data struct is
+          initialized (i.e. when the lookup concludes). */
+       int (*credentials_lookup)(struct auth_scram_server *server,
+                                 struct auth_scram_key_data *key_data);
+};
+
+struct auth_scram_server {
+       pool_t pool;
+       const struct hash_method *hash_method;
+
+       /* Backend API */
+       const struct auth_scram_server_backend *backend;
+       void *context;
+
+       enum auth_scram_server_state state;
+
+       /* Sent: */
+       const char *server_first_message;
+       const char *snonce;
+
+       /* Received: */
+       const char *gs2_header;
+       const char *cnonce;
+       const char *client_first_message_bare;
+       const char *client_final_message_without_proof;
+       buffer_t *proof;
+
+       /* Looked up: */
+       struct auth_scram_key_data key_data;
+};
+
+void auth_scram_server_init(struct auth_scram_server *server_r, pool_t pool,
+                           const struct hash_method *hmethod,
+                           const struct auth_scram_server_backend *backend);
+void auth_scram_server_deinit(struct auth_scram_server *server);
+
+/* Returns TRUE if authentication was concluded successfully. */
+bool auth_scram_server_acces_granted(struct auth_scram_server *server);
+
+/* Pass client input to the server. Returns 1 if server output is available, 0
+   if no server output is available yet (e.g. pending credentials lookup), and
+   -1 upon error (error_code_r and error_r are set accordingly). */
+int auth_scram_server_input(struct auth_scram_server *server,
+                           const unsigned char *input, size_t input_len,
+                           enum auth_scram_server_error *error_code_r,
+                           const char **error_r);
+/* Obtain output from server. This will assert fail if called out of sequence.
+   Returns TRUE if this is the last authentication step and success may be
+   indicated to the client or FALSE when the authentication handshake continues.
+ */
+bool auth_scram_server_output(struct auth_scram_server *server,
+                             const unsigned char **output_r,
+                             size_t *output_len_r);
+
+#endif
index 6956a4777c78fa042120c00824d85d21d989dadc..ea9194b287ac5306b1da87be67133fe9664f58b8 100644 (file)
@@ -2,6 +2,7 @@
 #define AUTH_SCRAM_H
 
 struct auth_scram_key_data {
+       pool_t pool;
        const struct hash_method *hmethod;
 
        unsigned int iter_count;
index eba0467bdbb8af6fd420baca7a42158d75f4f539..0c3a399ae0f311b7324604353429fc994183926d 100644 (file)
@@ -1,24 +1,9 @@
-/*
- * SCRAM-SHA-1 SASL authentication, see RFC-5802
- *
- * Copyright (c) 2011-2016 Florian Zeitz <florob@babelmonkeys.de>
- *
- * This software is released under the MIT license.
- */
+/* Copyright (c) 2011-2023 Dovecot authors, see the included COPYING file */
 
 #include "auth-common.h"
-#include "base64.h"
-#include "buffer.h"
-#include "hmac.h"
 #include "sha1.h"
 #include "sha2.h"
-#include "randgen.h"
-#include "safe-memset.h"
-#include "str.h"
-#include "strfuncs.h"
-#include "strnum.h"
-#include "password-scheme.h"
-#include "auth-scram.h"
+#include "auth-scram-server.h"
 #include "mech.h"
 #include "mech-scram.h"
 
@@ -29,27 +14,12 @@ struct scram_auth_request {
        struct auth_request auth_request;
 
        pool_t pool;
-
-       const struct hash_method *hash_method;
        const char *password_scheme;
 
-       /* sent: */
-       const char *server_first_message;
-       const char *snonce;
-
-       /* received: */
-       const char *gs2_header;
-       const char *cnonce;
-       const char *client_first_message_bare;
-       const char *client_final_message_without_proof;
-       buffer_t *proof;
-
-       /* looked up: */
-       struct auth_scram_key_data key_data;
+       struct auth_scram_server scram_server;
+       struct auth_scram_key_data *key_data;
 };
 
-#include "auth-scram-server.c"
-
 static void
 credentials_callback(enum passdb_result result,
                     const unsigned char *credentials, size_t size,
@@ -58,8 +28,11 @@ credentials_callback(enum passdb_result result,
        struct scram_auth_request *request =
                container_of(auth_request, struct scram_auth_request,
                             auth_request);
-       struct auth_scram_key_data *key_data = &request->key_data;
+       struct auth_scram_key_data *key_data = request->key_data;
        const char *error;
+       const unsigned char *output;
+       size_t output_len;
+       bool end;
 
        switch (result) {
        case PASSDB_RESULT_OK:
@@ -76,12 +49,11 @@ credentials_callback(enum passdb_result result,
                        break;
                }
 
-               request->server_first_message = p_strdup(request->pool,
-                       str_c(auth_scram_get_server_first(request)));
-
+               end = auth_scram_server_output(&request->scram_server,
+                                              &output, &output_len);
+               i_assert(!end);
                auth_request_handler_reply_continue(auth_request,
-                                       request->server_first_message,
-                                       strlen(request->server_first_message));
+                                                   output, output_len);
                break;
        case PASSDB_RESULT_INTERNAL_FAILURE:
                auth_request_internal_failure(auth_request);
@@ -92,47 +64,85 @@ credentials_callback(enum passdb_result result,
        }
 }
 
+static bool
+mech_scram_set_username(struct auth_scram_server *asserver,
+                       const char *username, const char **error_r)
+{
+       struct scram_auth_request *request =
+               container_of(asserver, struct scram_auth_request, scram_server);
+
+       return auth_request_set_username(&request->auth_request,
+                                        username, error_r);
+}
+
+static bool
+mech_scram_set_login_username(struct auth_scram_server *asserver,
+                             const char *username, const char **error_r)
+{
+       struct scram_auth_request *request =
+               container_of(asserver, struct scram_auth_request, scram_server);
+
+       return auth_request_set_login_username(&request->auth_request,
+                                              username, error_r);
+}
+
+static int
+mech_scram_credentials_lookup(struct auth_scram_server *asserver,
+                             struct auth_scram_key_data *key_data)
+{
+       struct scram_auth_request *request =
+               container_of(asserver, struct scram_auth_request, scram_server);
+
+       request->key_data = key_data;
+       auth_request_lookup_credentials(&request->auth_request,
+                                       request->password_scheme,
+                                       credentials_callback);
+       return 0;
+}
+
+static const struct auth_scram_server_backend scram_server_backend = {
+       .set_username = mech_scram_set_username,
+       .set_login_username = mech_scram_set_login_username,
+
+       .credentials_lookup = mech_scram_credentials_lookup,
+};
+
 void mech_scram_auth_continue(struct auth_request *auth_request,
-                             const unsigned char *data, size_t data_size)
+                             const unsigned char *input, size_t input_len)
 {
        struct scram_auth_request *request =
                container_of(auth_request, struct scram_auth_request,
                             auth_request);
+       enum auth_scram_server_error error_code;
        const char *error = NULL;
-       const char *server_final_message;
-       size_t len;
-
-       if (request->client_first_message_bare == NULL) {
-               /* Received client-first-message */
-               if (auth_scram_parse_client_first(request, data,
-                                                 data_size, &error) >= 0) {
-                       auth_request_lookup_credentials(
-                               &request->auth_request,
-                               request->password_scheme,
-                               credentials_callback);
-                       return;
-               }
-       } else {
-               /* Received client-final-message */
-               if (auth_scram_parse_client_final(request, data, data_size,
-                                                 &error) >= 0) {
-                       if (!auth_scram_server_verify_credentials(request)) {
-                               e_info(auth_request->mech_event,
-                                      AUTH_LOG_MSG_PASSWORD_MISMATCH);
-                       } else {
-                               server_final_message =
-                                       str_c(auth_scram_get_server_final(request));
-                               len = strlen(server_final_message);
-                               auth_request_success(auth_request,
-                                                    server_final_message, len);
-                               return;
-                       }
+       const unsigned char *output;
+       size_t output_len;
+       int ret;
+
+       ret = auth_scram_server_input(&request->scram_server, input, input_len,
+                                     &error_code, &error);
+       if (ret < 0) {
+               i_assert(error != NULL);
+               if (error_code == AUTH_SCRAM_SERVER_ERROR_VERIFICATION_FAILED) {
+                       e_info(auth_request->mech_event,
+                              AUTH_LOG_MSG_PASSWORD_MISMATCH);
+               } else {
+                       e_info(auth_request->mech_event, "%s", error);
                }
+               auth_request_fail(auth_request);
+               return;
        }
+       if (ret == 0)
+               return;
 
-       if (error != NULL)
-               e_info(auth_request->mech_event, "%s", error);
-       auth_request_fail(auth_request);
+       if (!auth_scram_server_output(&request->scram_server,
+                                     &output, &output_len)) {
+               auth_request_handler_reply_continue(auth_request,
+                                                   output, output_len);
+               return;
+       }
+
+       auth_request_success(auth_request, output, output_len);
 }
 
 struct auth_request *
@@ -145,14 +155,10 @@ mech_scram_auth_new(const struct hash_method *hash_method,
        pool = pool_alloconly_create(MEMPOOL_GROWING"scram_auth_request", 2048);
        request = p_new(pool, struct scram_auth_request, 1);
        request->pool = pool;
-
-       request->hash_method = hash_method;
        request->password_scheme = password_scheme;
 
-       i_zero(&request->key_data);
-       request->key_data.hmethod = hash_method;
-       request->key_data.stored_key = p_malloc(pool, hash_method->digest_size);
-       request->key_data.server_key = p_malloc(pool, hash_method->digest_size);
+       auth_scram_server_init(&request->scram_server, pool,
+                              hash_method, &scram_server_backend);
 
        request->auth_request.pool = pool;
        return &request->auth_request;
@@ -170,6 +176,11 @@ static struct auth_request *mech_scram_sha256_auth_new(void)
 
 static void mech_scram_auth_free(struct auth_request *auth_request)
 {
+       struct scram_auth_request *request =
+               container_of(auth_request, struct scram_auth_request,
+                            auth_request);
+
+       auth_scram_server_deinit(&request->scram_server);
        pool_unref(&auth_request->pool);
 }