mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-11-04 03:27:12 -05:00 
			
		
		
		
	Fix: disable API basic auth if MFA enabled (#8792)
This commit is contained in:
		
							parent
							
								
									29726c3ce1
								
							
						
					
					
						commit
						5e3ee3a80d
					
				@ -1,4 +1,6 @@
 | 
				
			|||||||
 | 
					import base64
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
 | 
					from unittest import mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from allauth.mfa.models import Authenticator
 | 
					from allauth.mfa.models import Authenticator
 | 
				
			||||||
from django.contrib.auth.models import Group
 | 
					from django.contrib.auth.models import Group
 | 
				
			||||||
@ -462,6 +464,30 @@ class TestApiAuth(DirectoriesMixin, APITestCase):
 | 
				
			|||||||
        self.assertNotIn("user_can_change", results[0])
 | 
					        self.assertNotIn("user_can_change", results[0])
 | 
				
			||||||
        self.assertNotIn("is_shared_by_requester", results[0])
 | 
					        self.assertNotIn("is_shared_by_requester", results[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @mock.patch("allauth.mfa.adapter.DefaultMFAAdapter.is_mfa_enabled")
 | 
				
			||||||
 | 
					    def test_basic_auth_mfa_enabled(self, mock_is_mfa_enabled):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        GIVEN:
 | 
				
			||||||
 | 
					            - User with MFA enabled
 | 
				
			||||||
 | 
					        WHEN:
 | 
				
			||||||
 | 
					            - API request is made with basic auth
 | 
				
			||||||
 | 
					        THEN:
 | 
				
			||||||
 | 
					            - MFA required error is returned
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        user1 = User.objects.create_user(username="user1")
 | 
				
			||||||
 | 
					        user1.set_password("password")
 | 
				
			||||||
 | 
					        user1.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        mock_is_mfa_enabled.return_value = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = self.client.get(
 | 
				
			||||||
 | 
					            "/api/documents/",
 | 
				
			||||||
 | 
					            HTTP_AUTHORIZATION="Basic " + base64.b64encode(b"user1:password").decode(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
 | 
				
			||||||
 | 
					        self.assertEqual(response.data["detail"], "MFA required")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestApiUser(DirectoriesMixin, APITestCase):
 | 
					class TestApiUser(DirectoriesMixin, APITestCase):
 | 
				
			||||||
    ENDPOINT = "/api/users/"
 | 
					    ENDPOINT = "/api/users/"
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,6 @@
 | 
				
			|||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from allauth.mfa.adapter import get_adapter as get_mfa_adapter
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
from django.contrib import auth
 | 
					from django.contrib import auth
 | 
				
			||||||
from django.contrib.auth.middleware import PersistentRemoteUserMiddleware
 | 
					from django.contrib.auth.middleware import PersistentRemoteUserMiddleware
 | 
				
			||||||
@ -7,6 +8,7 @@ from django.contrib.auth.models import User
 | 
				
			|||||||
from django.http import HttpRequest
 | 
					from django.http import HttpRequest
 | 
				
			||||||
from django.utils.deprecation import MiddlewareMixin
 | 
					from django.utils.deprecation import MiddlewareMixin
 | 
				
			||||||
from rest_framework import authentication
 | 
					from rest_framework import authentication
 | 
				
			||||||
 | 
					from rest_framework import exceptions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger("paperless.auth")
 | 
					logger = logging.getLogger("paperless.auth")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -70,3 +72,14 @@ class PaperlessRemoteUserAuthentication(authentication.RemoteUserAuthentication)
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    header = settings.HTTP_REMOTE_USER_HEADER_NAME
 | 
					    header = settings.HTTP_REMOTE_USER_HEADER_NAME
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PaperlessBasicAuthentication(authentication.BasicAuthentication):
 | 
				
			||||||
 | 
					    def authenticate(self, request):
 | 
				
			||||||
 | 
					        user_tuple = super().authenticate(request)
 | 
				
			||||||
 | 
					        user = user_tuple[0] if user_tuple else None
 | 
				
			||||||
 | 
					        mfa_adapter = get_mfa_adapter()
 | 
				
			||||||
 | 
					        if user and mfa_adapter.is_mfa_enabled(user):
 | 
				
			||||||
 | 
					            raise exceptions.AuthenticationFailed("MFA required")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return user_tuple
 | 
				
			||||||
 | 
				
			|||||||
@ -336,7 +336,7 @@ if DEBUG:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
REST_FRAMEWORK = {
 | 
					REST_FRAMEWORK = {
 | 
				
			||||||
    "DEFAULT_AUTHENTICATION_CLASSES": [
 | 
					    "DEFAULT_AUTHENTICATION_CLASSES": [
 | 
				
			||||||
        "rest_framework.authentication.BasicAuthentication",
 | 
					        "paperless.auth.PaperlessBasicAuthentication",
 | 
				
			||||||
        "rest_framework.authentication.TokenAuthentication",
 | 
					        "rest_framework.authentication.TokenAuthentication",
 | 
				
			||||||
        "rest_framework.authentication.SessionAuthentication",
 | 
					        "rest_framework.authentication.SessionAuthentication",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user