mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-11-03 19:17:13 -05:00 
			
		
		
		
	Enhancement: require totp code for obtain auth token (#8936)
This commit is contained in:
		
							parent
							
								
									978b072bff
								
							
						
					
					
						commit
						79956d6a7b
					
				@ -3,6 +3,7 @@ import json
 | 
				
			|||||||
from unittest import mock
 | 
					from unittest import mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from allauth.mfa.models import Authenticator
 | 
					from allauth.mfa.models import Authenticator
 | 
				
			||||||
 | 
					from allauth.mfa.totp.internal import auth as totp_auth
 | 
				
			||||||
from django.contrib.auth.models import Group
 | 
					from django.contrib.auth.models import Group
 | 
				
			||||||
from django.contrib.auth.models import Permission
 | 
					from django.contrib.auth.models import Permission
 | 
				
			||||||
from django.contrib.auth.models import User
 | 
					from django.contrib.auth.models import User
 | 
				
			||||||
@ -488,6 +489,71 @@ class TestApiAuth(DirectoriesMixin, APITestCase):
 | 
				
			|||||||
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
 | 
					        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
 | 
				
			||||||
        self.assertEqual(response.data["detail"], "MFA required")
 | 
					        self.assertEqual(response.data["detail"], "MFA required")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @mock.patch("allauth.mfa.totp.internal.auth.TOTP.validate_code")
 | 
				
			||||||
 | 
					    def test_get_token_mfa_enabled(self, mock_validate_code):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        GIVEN:
 | 
				
			||||||
 | 
					            - User with MFA enabled
 | 
				
			||||||
 | 
					        WHEN:
 | 
				
			||||||
 | 
					            - API request is made to obtain an auth token
 | 
				
			||||||
 | 
					        THEN:
 | 
				
			||||||
 | 
					            - MFA code is required
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        user1 = User.objects.create_user(username="user1")
 | 
				
			||||||
 | 
					        user1.set_password("password")
 | 
				
			||||||
 | 
					        user1.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            "/api/token/",
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "username": "user1",
 | 
				
			||||||
 | 
					                "password": "password",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, status.HTTP_200_OK)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        secret = totp_auth.generate_totp_secret()
 | 
				
			||||||
 | 
					        totp_auth.TOTP.activate(
 | 
				
			||||||
 | 
					            user1,
 | 
				
			||||||
 | 
					            secret,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # no code
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            "/api/token/",
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "username": "user1",
 | 
				
			||||||
 | 
					                "password": "password",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 | 
				
			||||||
 | 
					        self.assertEqual(response.data["non_field_errors"][0], "MFA code is required")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # invalid code
 | 
				
			||||||
 | 
					        mock_validate_code.return_value = False
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            "/api/token/",
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "username": "user1",
 | 
				
			||||||
 | 
					                "password": "password",
 | 
				
			||||||
 | 
					                "code": "123456",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 | 
				
			||||||
 | 
					        self.assertEqual(response.data["non_field_errors"][0], "Invalid MFA code")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # valid code
 | 
				
			||||||
 | 
					        mock_validate_code.return_value = True
 | 
				
			||||||
 | 
					        response = self.client.post(
 | 
				
			||||||
 | 
					            "/api/token/",
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "username": "user1",
 | 
				
			||||||
 | 
					                "password": "password",
 | 
				
			||||||
 | 
					                "code": "123456",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, status.HTTP_200_OK)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestApiUser(DirectoriesMixin, APITestCase):
 | 
					class TestApiUser(DirectoriesMixin, APITestCase):
 | 
				
			||||||
    ENDPOINT = "/api/users/"
 | 
					    ENDPOINT = "/api/users/"
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,14 @@
 | 
				
			|||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
 | 
					from allauth.mfa.adapter import get_adapter as get_mfa_adapter
 | 
				
			||||||
 | 
					from allauth.mfa.models import Authenticator
 | 
				
			||||||
 | 
					from allauth.mfa.totp.internal.auth import TOTP
 | 
				
			||||||
from allauth.socialaccount.models import SocialAccount
 | 
					from allauth.socialaccount.models import SocialAccount
 | 
				
			||||||
from django.contrib.auth.models import Group
 | 
					from django.contrib.auth.models import Group
 | 
				
			||||||
from django.contrib.auth.models import Permission
 | 
					from django.contrib.auth.models import Permission
 | 
				
			||||||
from django.contrib.auth.models import User
 | 
					from django.contrib.auth.models import User
 | 
				
			||||||
from rest_framework import serializers
 | 
					from rest_framework import serializers
 | 
				
			||||||
 | 
					from rest_framework.authtoken.serializers import AuthTokenSerializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from paperless.models import ApplicationConfiguration
 | 
					from paperless.models import ApplicationConfiguration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -24,6 +27,36 @@ class ObfuscatedUserPasswordField(serializers.Field):
 | 
				
			|||||||
        return data
 | 
					        return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PaperlessAuthTokenSerializer(AuthTokenSerializer):
 | 
				
			||||||
 | 
					    code = serializers.CharField(
 | 
				
			||||||
 | 
					        label="MFA Code",
 | 
				
			||||||
 | 
					        write_only=True,
 | 
				
			||||||
 | 
					        required=False,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate(self, attrs):
 | 
				
			||||||
 | 
					        attrs = super().validate(attrs)
 | 
				
			||||||
 | 
					        user = attrs.get("user")
 | 
				
			||||||
 | 
					        code = attrs.get("code")
 | 
				
			||||||
 | 
					        mfa_adapter = get_mfa_adapter()
 | 
				
			||||||
 | 
					        if mfa_adapter.is_mfa_enabled(user):
 | 
				
			||||||
 | 
					            if not code:
 | 
				
			||||||
 | 
					                raise serializers.ValidationError(
 | 
				
			||||||
 | 
					                    "MFA code is required",
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            authenticator = Authenticator.objects.get(
 | 
				
			||||||
 | 
					                user=user,
 | 
				
			||||||
 | 
					                type=Authenticator.Type.TOTP,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            if not TOTP(instance=authenticator).validate_code(
 | 
				
			||||||
 | 
					                code,
 | 
				
			||||||
 | 
					            ):
 | 
				
			||||||
 | 
					                raise serializers.ValidationError(
 | 
				
			||||||
 | 
					                    "Invalid MFA code",
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					        return attrs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class UserSerializer(serializers.ModelSerializer):
 | 
					class UserSerializer(serializers.ModelSerializer):
 | 
				
			||||||
    password = ObfuscatedUserPasswordField(required=False)
 | 
					    password = ObfuscatedUserPasswordField(required=False)
 | 
				
			||||||
    user_permissions = serializers.SlugRelatedField(
 | 
					    user_permissions = serializers.SlugRelatedField(
 | 
				
			||||||
 | 
				
			|||||||
@ -14,7 +14,6 @@ from django.utils.translation import gettext_lazy as _
 | 
				
			|||||||
from django.views.decorators.csrf import ensure_csrf_cookie
 | 
					from django.views.decorators.csrf import ensure_csrf_cookie
 | 
				
			||||||
from django.views.generic import RedirectView
 | 
					from django.views.generic import RedirectView
 | 
				
			||||||
from django.views.static import serve
 | 
					from django.views.static import serve
 | 
				
			||||||
from rest_framework.authtoken import views
 | 
					 | 
				
			||||||
from rest_framework.routers import DefaultRouter
 | 
					from rest_framework.routers import DefaultRouter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from documents.views import BulkDownloadView
 | 
					from documents.views import BulkDownloadView
 | 
				
			||||||
@ -50,6 +49,7 @@ from paperless.views import DisconnectSocialAccountView
 | 
				
			|||||||
from paperless.views import FaviconView
 | 
					from paperless.views import FaviconView
 | 
				
			||||||
from paperless.views import GenerateAuthTokenView
 | 
					from paperless.views import GenerateAuthTokenView
 | 
				
			||||||
from paperless.views import GroupViewSet
 | 
					from paperless.views import GroupViewSet
 | 
				
			||||||
 | 
					from paperless.views import PaperlessObtainAuthTokenView
 | 
				
			||||||
from paperless.views import ProfileView
 | 
					from paperless.views import ProfileView
 | 
				
			||||||
from paperless.views import SocialAccountProvidersView
 | 
					from paperless.views import SocialAccountProvidersView
 | 
				
			||||||
from paperless.views import TOTPView
 | 
					from paperless.views import TOTPView
 | 
				
			||||||
@ -157,7 +157,7 @@ urlpatterns = [
 | 
				
			|||||||
                ),
 | 
					                ),
 | 
				
			||||||
                path(
 | 
					                path(
 | 
				
			||||||
                    "token/",
 | 
					                    "token/",
 | 
				
			||||||
                    views.obtain_auth_token,
 | 
					                    PaperlessObtainAuthTokenView.as_view(),
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
                re_path(
 | 
					                re_path(
 | 
				
			||||||
                    "^profile/",
 | 
					                    "^profile/",
 | 
				
			||||||
 | 
				
			|||||||
@ -19,6 +19,7 @@ from django.http import HttpResponseNotFound
 | 
				
			|||||||
from django.views.generic import View
 | 
					from django.views.generic import View
 | 
				
			||||||
from django_filters.rest_framework import DjangoFilterBackend
 | 
					from django_filters.rest_framework import DjangoFilterBackend
 | 
				
			||||||
from rest_framework.authtoken.models import Token
 | 
					from rest_framework.authtoken.models import Token
 | 
				
			||||||
 | 
					from rest_framework.authtoken.views import ObtainAuthToken
 | 
				
			||||||
from rest_framework.decorators import action
 | 
					from rest_framework.decorators import action
 | 
				
			||||||
from rest_framework.filters import OrderingFilter
 | 
					from rest_framework.filters import OrderingFilter
 | 
				
			||||||
from rest_framework.generics import GenericAPIView
 | 
					from rest_framework.generics import GenericAPIView
 | 
				
			||||||
@ -35,10 +36,15 @@ from paperless.filters import UserFilterSet
 | 
				
			|||||||
from paperless.models import ApplicationConfiguration
 | 
					from paperless.models import ApplicationConfiguration
 | 
				
			||||||
from paperless.serialisers import ApplicationConfigurationSerializer
 | 
					from paperless.serialisers import ApplicationConfigurationSerializer
 | 
				
			||||||
from paperless.serialisers import GroupSerializer
 | 
					from paperless.serialisers import GroupSerializer
 | 
				
			||||||
 | 
					from paperless.serialisers import PaperlessAuthTokenSerializer
 | 
				
			||||||
from paperless.serialisers import ProfileSerializer
 | 
					from paperless.serialisers import ProfileSerializer
 | 
				
			||||||
from paperless.serialisers import UserSerializer
 | 
					from paperless.serialisers import UserSerializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PaperlessObtainAuthTokenView(ObtainAuthToken):
 | 
				
			||||||
 | 
					    serializer_class = PaperlessAuthTokenSerializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class StandardPagination(PageNumberPagination):
 | 
					class StandardPagination(PageNumberPagination):
 | 
				
			||||||
    page_size = 25
 | 
					    page_size = 25
 | 
				
			||||||
    page_size_query_param = "page_size"
 | 
					    page_size_query_param = "page_size"
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user