+import base64
import json
+from unittest import mock
from allauth.mfa.models import Authenticator
from django.contrib.auth.models import Group
self.assertNotIn("user_can_change", results[0])
self.assertNotIn("is_shared_by_requester", results[0])
+ @mock.patch("allauth.mfa.adapter.DefaultMFAAdapter.is_mfa_enabled")
+ def test_basic_auth_mfa_enabled(self, mock_is_mfa_enabled):
+ """
+ GIVEN:
+ - User with MFA enabled
+ WHEN:
+ - API request is made with basic auth
+ THEN:
+ - MFA required error is returned
+ """
+ user1 = User.objects.create_user(username="user1")
+ user1.set_password("password")
+ user1.save()
+
+ mock_is_mfa_enabled.return_value = True
+
+ response = self.client.get(
+ "/api/documents/",
+ HTTP_AUTHORIZATION="Basic " + base64.b64encode(b"user1:password").decode(),
+ )
+
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+ self.assertEqual(response.data["detail"], "MFA required")
+
class TestApiUser(DirectoriesMixin, APITestCase):
ENDPOINT = "/api/users/"
import logging
+from allauth.mfa.adapter import get_adapter as get_mfa_adapter
from django.conf import settings
from django.contrib import auth
from django.contrib.auth.middleware import PersistentRemoteUserMiddleware
from django.http import HttpRequest
from django.utils.deprecation import MiddlewareMixin
from rest_framework import authentication
+from rest_framework import exceptions
logger = logging.getLogger("paperless.auth")
"""
header = settings.HTTP_REMOTE_USER_HEADER_NAME
+
+
+class PaperlessBasicAuthentication(authentication.BasicAuthentication):
+ def authenticate(self, request):
+ user_tuple = super().authenticate(request)
+ user = user_tuple[0] if user_tuple else None
+ mfa_adapter = get_mfa_adapter()
+ if user and mfa_adapter.is_mfa_enabled(user):
+ raise exceptions.AuthenticationFailed("MFA required")
+
+ return user_tuple
REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": [
- "rest_framework.authentication.BasicAuthentication",
+ "paperless.auth.PaperlessBasicAuthentication",
"rest_framework.authentication.TokenAuthentication",
"rest_framework.authentication.SessionAuthentication",
],