from unittest import mock
from allauth.mfa.models import Authenticator
+from allauth.mfa.totp.internal import auth as totp_auth
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.data["detail"], "MFA required")
+ @mock.patch("allauth.mfa.totp.internal.auth.TOTP.validate_code")
+ def test_get_token_mfa_enabled(self, mock_validate_code):
+ """
+ GIVEN:
+ - User with MFA enabled
+ WHEN:
+ - API request is made to obtain an auth token
+ THEN:
+ - MFA code is required
+ """
+ user1 = User.objects.create_user(username="user1")
+ user1.set_password("password")
+ user1.save()
+
+ response = self.client.post(
+ "/api/token/",
+ data={
+ "username": "user1",
+ "password": "password",
+ },
+ )
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ secret = totp_auth.generate_totp_secret()
+ totp_auth.TOTP.activate(
+ user1,
+ secret,
+ )
+
+ # no code
+ response = self.client.post(
+ "/api/token/",
+ data={
+ "username": "user1",
+ "password": "password",
+ },
+ )
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data["non_field_errors"][0], "MFA code is required")
+
+ # invalid code
+ mock_validate_code.return_value = False
+ response = self.client.post(
+ "/api/token/",
+ data={
+ "username": "user1",
+ "password": "password",
+ "code": "123456",
+ },
+ )
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data["non_field_errors"][0], "Invalid MFA code")
+
+ # valid code
+ mock_validate_code.return_value = True
+ response = self.client.post(
+ "/api/token/",
+ data={
+ "username": "user1",
+ "password": "password",
+ "code": "123456",
+ },
+ )
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
class TestApiUser(DirectoriesMixin, APITestCase):
ENDPOINT = "/api/users/"
import logging
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
+from allauth.mfa.models import Authenticator
+from allauth.mfa.totp.internal.auth import TOTP
from allauth.socialaccount.models import SocialAccount
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from rest_framework import serializers
+from rest_framework.authtoken.serializers import AuthTokenSerializer
from paperless.models import ApplicationConfiguration
return data
+class PaperlessAuthTokenSerializer(AuthTokenSerializer):
+ code = serializers.CharField(
+ label="MFA Code",
+ write_only=True,
+ required=False,
+ )
+
+ def validate(self, attrs):
+ attrs = super().validate(attrs)
+ user = attrs.get("user")
+ code = attrs.get("code")
+ mfa_adapter = get_mfa_adapter()
+ if mfa_adapter.is_mfa_enabled(user):
+ if not code:
+ raise serializers.ValidationError(
+ "MFA code is required",
+ )
+ authenticator = Authenticator.objects.get(
+ user=user,
+ type=Authenticator.Type.TOTP,
+ )
+ if not TOTP(instance=authenticator).validate_code(
+ code,
+ ):
+ raise serializers.ValidationError(
+ "Invalid MFA code",
+ )
+ return attrs
+
+
class UserSerializer(serializers.ModelSerializer):
password = ObfuscatedUserPasswordField(required=False)
user_permissions = serializers.SlugRelatedField(
from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import RedirectView
from django.views.static import serve
-from rest_framework.authtoken import views
from rest_framework.routers import DefaultRouter
from documents.views import BulkDownloadView
from paperless.views import FaviconView
from paperless.views import GenerateAuthTokenView
from paperless.views import GroupViewSet
+from paperless.views import PaperlessObtainAuthTokenView
from paperless.views import ProfileView
from paperless.views import SocialAccountProvidersView
from paperless.views import TOTPView
),
path(
"token/",
- views.obtain_auth_token,
+ PaperlessObtainAuthTokenView.as_view(),
),
re_path(
"^profile/",
from django.views.generic import View
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.authtoken.models import Token
+from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.decorators import action
from rest_framework.filters import OrderingFilter
from rest_framework.generics import GenericAPIView
from paperless.models import ApplicationConfiguration
from paperless.serialisers import ApplicationConfigurationSerializer
from paperless.serialisers import GroupSerializer
+from paperless.serialisers import PaperlessAuthTokenSerializer
from paperless.serialisers import ProfileSerializer
from paperless.serialisers import UserSerializer
+class PaperlessObtainAuthTokenView(ObtainAuthToken):
+ serializer_class = PaperlessAuthTokenSerializer
+
+
class StandardPagination(PageNumberPagination):
page_size = 25
page_size_query_param = "page_size"