]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Fix: disable API basic auth if MFA enabled (#8792)
authorshamoon <4887959+shamoon@users.noreply.github.com>
Sat, 18 Jan 2025 03:51:53 +0000 (19:51 -0800)
committerGitHub <noreply@github.com>
Sat, 18 Jan 2025 03:51:53 +0000 (03:51 +0000)
src/documents/tests/test_api_permissions.py
src/paperless/auth.py
src/paperless/settings.py

index eeea830cbab42a068be15dbc28f53ae3ffb6ca18..ef50c55f71cab3447698dc4d4c0587487ee31462 100644 (file)
@@ -1,4 +1,6 @@
+import base64
 import json
+from unittest import mock
 
 from allauth.mfa.models import Authenticator
 from django.contrib.auth.models import Group
@@ -462,6 +464,30 @@ class TestApiAuth(DirectoriesMixin, APITestCase):
         self.assertNotIn("user_can_change", results[0])
         self.assertNotIn("is_shared_by_requester", results[0])
 
+    @mock.patch("allauth.mfa.adapter.DefaultMFAAdapter.is_mfa_enabled")
+    def test_basic_auth_mfa_enabled(self, mock_is_mfa_enabled):
+        """
+        GIVEN:
+            - User with MFA enabled
+        WHEN:
+            - API request is made with basic auth
+        THEN:
+            - MFA required error is returned
+        """
+        user1 = User.objects.create_user(username="user1")
+        user1.set_password("password")
+        user1.save()
+
+        mock_is_mfa_enabled.return_value = True
+
+        response = self.client.get(
+            "/api/documents/",
+            HTTP_AUTHORIZATION="Basic " + base64.b64encode(b"user1:password").decode(),
+        )
+
+        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+        self.assertEqual(response.data["detail"], "MFA required")
+
 
 class TestApiUser(DirectoriesMixin, APITestCase):
     ENDPOINT = "/api/users/"
index 6ca97d608141aeef8b743d25bb2f22fd79b40348..36131847fbe3eba047c5006ac3000168124bf6b7 100644 (file)
@@ -1,5 +1,6 @@
 import logging
 
+from allauth.mfa.adapter import get_adapter as get_mfa_adapter
 from django.conf import settings
 from django.contrib import auth
 from django.contrib.auth.middleware import PersistentRemoteUserMiddleware
@@ -7,6 +8,7 @@ from django.contrib.auth.models import User
 from django.http import HttpRequest
 from django.utils.deprecation import MiddlewareMixin
 from rest_framework import authentication
+from rest_framework import exceptions
 
 logger = logging.getLogger("paperless.auth")
 
@@ -70,3 +72,14 @@ class PaperlessRemoteUserAuthentication(authentication.RemoteUserAuthentication)
     """
 
     header = settings.HTTP_REMOTE_USER_HEADER_NAME
+
+
+class PaperlessBasicAuthentication(authentication.BasicAuthentication):
+    def authenticate(self, request):
+        user_tuple = super().authenticate(request)
+        user = user_tuple[0] if user_tuple else None
+        mfa_adapter = get_mfa_adapter()
+        if user and mfa_adapter.is_mfa_enabled(user):
+            raise exceptions.AuthenticationFailed("MFA required")
+
+        return user_tuple
index 3fc9bfdbf83a6ef428a97d51b4a3a621e63d6318..ef842dde6ab7b1d5c02f6691ce7488e610275a9f 100644 (file)
@@ -336,7 +336,7 @@ if DEBUG:
 
 REST_FRAMEWORK = {
     "DEFAULT_AUTHENTICATION_CLASSES": [
-        "rest_framework.authentication.BasicAuthentication",
+        "paperless.auth.PaperlessBasicAuthentication",
         "rest_framework.authentication.TokenAuthentication",
         "rest_framework.authentication.SessionAuthentication",
     ],