]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Enhancement: require totp code for obtain auth token (#8936)
authorshamoon <4887959+shamoon@users.noreply.github.com>
Wed, 29 Jan 2025 15:23:44 +0000 (07:23 -0800)
committershamoon <4887959+shamoon@users.noreply.github.com>
Fri, 31 Jan 2025 15:44:47 +0000 (07:44 -0800)
src/documents/tests/test_api_permissions.py
src/paperless/serialisers.py
src/paperless/urls.py
src/paperless/views.py

index 5de1887b294efbcb43456f48bfb03135b561d6de..3785c8f2a6cc9369142dfa133f3e9ee72c30ffa1 100644 (file)
@@ -3,6 +3,7 @@ import json
 from unittest import mock
 
 from allauth.mfa.models import Authenticator
+from allauth.mfa.totp.internal import auth as totp_auth
 from django.contrib.auth.models import Group
 from django.contrib.auth.models import Permission
 from django.contrib.auth.models import User
@@ -488,6 +489,71 @@ class TestApiAuth(DirectoriesMixin, APITestCase):
         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
         self.assertEqual(response.data["detail"], "MFA required")
 
+    @mock.patch("allauth.mfa.totp.internal.auth.TOTP.validate_code")
+    def test_get_token_mfa_enabled(self, mock_validate_code):
+        """
+        GIVEN:
+            - User with MFA enabled
+        WHEN:
+            - API request is made to obtain an auth token
+        THEN:
+            - MFA code is required
+        """
+        user1 = User.objects.create_user(username="user1")
+        user1.set_password("password")
+        user1.save()
+
+        response = self.client.post(
+            "/api/token/",
+            data={
+                "username": "user1",
+                "password": "password",
+            },
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        secret = totp_auth.generate_totp_secret()
+        totp_auth.TOTP.activate(
+            user1,
+            secret,
+        )
+
+        # no code
+        response = self.client.post(
+            "/api/token/",
+            data={
+                "username": "user1",
+                "password": "password",
+            },
+        )
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.data["non_field_errors"][0], "MFA code is required")
+
+        # invalid code
+        mock_validate_code.return_value = False
+        response = self.client.post(
+            "/api/token/",
+            data={
+                "username": "user1",
+                "password": "password",
+                "code": "123456",
+            },
+        )
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.data["non_field_errors"][0], "Invalid MFA code")
+
+        # valid code
+        mock_validate_code.return_value = True
+        response = self.client.post(
+            "/api/token/",
+            data={
+                "username": "user1",
+                "password": "password",
+                "code": "123456",
+            },
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+
 
 class TestApiUser(DirectoriesMixin, APITestCase):
     ENDPOINT = "/api/users/"
index d5acfe465a86ea24b0fe1ea3732950eeacf7994f..fb1f511f744dc7110d7151ead9ac9194cabc1826 100644 (file)
@@ -1,11 +1,14 @@
 import logging
 
 from allauth.mfa.adapter import get_adapter as get_mfa_adapter
+from allauth.mfa.models import Authenticator
+from allauth.mfa.totp.internal.auth import TOTP
 from allauth.socialaccount.models import SocialAccount
 from django.contrib.auth.models import Group
 from django.contrib.auth.models import Permission
 from django.contrib.auth.models import User
 from rest_framework import serializers
+from rest_framework.authtoken.serializers import AuthTokenSerializer
 
 from paperless.models import ApplicationConfiguration
 
@@ -24,6 +27,36 @@ class ObfuscatedUserPasswordField(serializers.Field):
         return data
 
 
+class PaperlessAuthTokenSerializer(AuthTokenSerializer):
+    code = serializers.CharField(
+        label="MFA Code",
+        write_only=True,
+        required=False,
+    )
+
+    def validate(self, attrs):
+        attrs = super().validate(attrs)
+        user = attrs.get("user")
+        code = attrs.get("code")
+        mfa_adapter = get_mfa_adapter()
+        if mfa_adapter.is_mfa_enabled(user):
+            if not code:
+                raise serializers.ValidationError(
+                    "MFA code is required",
+                )
+            authenticator = Authenticator.objects.get(
+                user=user,
+                type=Authenticator.Type.TOTP,
+            )
+            if not TOTP(instance=authenticator).validate_code(
+                code,
+            ):
+                raise serializers.ValidationError(
+                    "Invalid MFA code",
+                )
+        return attrs
+
+
 class UserSerializer(serializers.ModelSerializer):
     password = ObfuscatedUserPasswordField(required=False)
     user_permissions = serializers.SlugRelatedField(
index c528c5e2a2d665c5b1558ff8c4d6886f853c360d..703a72042af85245bca5bf58ce96c59dc9385901 100644 (file)
@@ -14,7 +14,6 @@ from django.utils.translation import gettext_lazy as _
 from django.views.decorators.csrf import ensure_csrf_cookie
 from django.views.generic import RedirectView
 from django.views.static import serve
-from rest_framework.authtoken import views
 from rest_framework.routers import DefaultRouter
 
 from documents.views import BulkDownloadView
@@ -50,6 +49,7 @@ from paperless.views import DisconnectSocialAccountView
 from paperless.views import FaviconView
 from paperless.views import GenerateAuthTokenView
 from paperless.views import GroupViewSet
+from paperless.views import PaperlessObtainAuthTokenView
 from paperless.views import ProfileView
 from paperless.views import SocialAccountProvidersView
 from paperless.views import TOTPView
@@ -157,7 +157,7 @@ urlpatterns = [
                 ),
                 path(
                     "token/",
-                    views.obtain_auth_token,
+                    PaperlessObtainAuthTokenView.as_view(),
                 ),
                 re_path(
                     "^profile/",
index 03721adf2c03cbf3b1ffad0bc31f66fe63c0e826..bcabd182f0c0bfcc7e70fdb6486e931e9bbc6785 100644 (file)
@@ -19,6 +19,7 @@ from django.http import HttpResponseNotFound
 from django.views.generic import View
 from django_filters.rest_framework import DjangoFilterBackend
 from rest_framework.authtoken.models import Token
+from rest_framework.authtoken.views import ObtainAuthToken
 from rest_framework.decorators import action
 from rest_framework.filters import OrderingFilter
 from rest_framework.generics import GenericAPIView
@@ -35,10 +36,15 @@ from paperless.filters import UserFilterSet
 from paperless.models import ApplicationConfiguration
 from paperless.serialisers import ApplicationConfigurationSerializer
 from paperless.serialisers import GroupSerializer
+from paperless.serialisers import PaperlessAuthTokenSerializer
 from paperless.serialisers import ProfileSerializer
 from paperless.serialisers import UserSerializer
 
 
+class PaperlessObtainAuthTokenView(ObtainAuthToken):
+    serializer_class = PaperlessAuthTokenSerializer
+
+
 class StandardPagination(PageNumberPagination):
     page_size = 25
     page_size_query_param = "page_size"