Mirror of https://github.com/paperless-ngx/paperless-ngx.git, synced 2025-11-04 03:27:12 -05:00
	Fixup some tests
This commit is contained in:
parent 9183bfc0a4
commit 4a28be233e
@@ -65,6 +65,8 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
                "barcode_enable_tag": None,
                "barcode_tag_mapping": None,
                "ai_enabled": False,
                "llm_embedding_backend": None,
                "llm_embedding_model": None,
                "llm_backend": None,
                "llm_model": None,
                "llm_api_key": None,
@@ -37,28 +37,65 @@ class OllamaLLM(LLM):
            data = response.json()
            return CompletionResponse(text=data["response"])

    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
        with httpx.Client(timeout=120.0) as client:
            response = client.post(
                f"{self.base_url}/api/generate",
                json={
                    "model": self.model,
                    "messages": [
                        {
                            "role": message.role,
                            "content": message.content,
                        }
                        for message in messages
                    ],
                    "stream": False,
                },
            )
            response.raise_for_status()
            data = response.json()
            return ChatResponse(text=data["response"])

    # -- Required stubs for ABC:
    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
    def stream_complete(
        self,
        prompt: str,
        **kwargs,
    ) -> CompletionResponseGen:  # pragma: no cover
        raise NotImplementedError("stream_complete not supported")

    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
        raise NotImplementedError("chat not supported")

    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
    def stream_chat(
        self,
        messages: list[ChatMessage],
        **kwargs,
    ) -> ChatResponseGen:  # pragma: no cover
        raise NotImplementedError("stream_chat not supported")

    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
    async def achat(
        self,
        messages: list[ChatMessage],
        **kwargs,
    ) -> ChatResponse:  # pragma: no cover
        raise NotImplementedError("async chat not supported")

    async def astream_chat(
        self,
        messages: list[ChatMessage],
        **kwargs,
    ) -> ChatResponseGen:
    ) -> ChatResponseGen:  # pragma: no cover
        raise NotImplementedError("async stream_chat not supported")

    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
    async def acomplete(
        self,
        prompt: str,
        **kwargs,
    ) -> CompletionResponse:  # pragma: no cover
        raise NotImplementedError("async complete not supported")

    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
    async def astream_complete(
        self,
        prompt: str,
        **kwargs,
    ) -> CompletionResponseGen:  # pragma: no cover
        raise NotImplementedError("async stream_complete not supported")
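
A minimal usage sketch of the OllamaLLM wrapper above, for orientation only. The constructor keywords model and base_url match how the client tests later in this commit instantiate it; the model name, URL, and import path are placeholders (the import simply mirrors the tests' patch target).

from llama_index.core.llms import ChatMessage

from paperless.ai.client import OllamaLLM  # name resolved as in the tests' patch target

# Placeholder model/URL pointing at a locally running Ollama instance.
llm = OllamaLLM(model="llama3", base_url="http://localhost:11434")

completion = llm.complete("Summarize this document.")
print(completion.text)  # complete() returns CompletionResponse(text=...)

reply = llm.chat([ChatMessage(role="user", content="Hello")])
print(reply.text)  # chat() above wraps the payload in ChatResponse(text=...)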
@@ -1419,10 +1419,9 @@ OUTLOOK_OAUTH_ENABLED = bool(
AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
LLM_EMBEDDING_BACKEND = os.getenv(
    "PAPERLESS_LLM_EMBEDDING_BACKEND",
    "local",
)  # or "openai"
)  # "local" or "openai"
LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "ollama")  # or "openai"
LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND")  # "ollama" or "openai"
LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
LLM_URL = os.getenv("PAPERLESS_LLM_URL")
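
Not part of the patch, but for context: the settings above are plain environment reads. After this change PAPERLESS_LLM_BACKEND has no default (previously "ollama"), while the embedding backend still defaults to "local". A placeholder illustration of setting them from Python:

import os

# Placeholder values; any unset variable falls back to None
# (or "local" for the embedding backend, "NO" for AI_ENABLED).
os.environ.setdefault("PAPERLESS_AI_ENABLED", "true")
os.environ.setdefault("PAPERLESS_LLM_BACKEND", "ollama")  # or "openai"
os.environ.setdefault("PAPERLESS_LLM_MODEL", "llama3")    # example model name
os.environ.setdefault("PAPERLESS_LLM_URL", "http://localhost:11434")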
@@ -1,11 +1,13 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from django.test import override_settings

from documents.models import Document
from paperless.ai.ai_classifier import get_ai_document_classification
from paperless.ai.ai_classifier import parse_ai_classification_response
from paperless.ai.ai_classifier import parse_ai_response


@pytest.fixture
@@ -15,8 +17,12 @@ def mock_document():

@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@override_settings(
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
    mock_response = json.dumps(
    mock_run_llm_query.return_value.text = json.dumps(
        {
            "title": "Test Title",
            "tags": ["test", "document"],
@@ -26,7 +32,6 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
            "dates": ["2023-01-01"],
        },
    )
    mock_run_llm_query.return_value = mock_response

    result = get_ai_document_classification(mock_document)
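
The mock change above, from return_value = mock_response to return_value.text = json.dumps(...), suggests run_llm_query now returns a llama_index-style completion object rather than a raw string. A small sketch of that assumption:

import json

from llama_index.core.llms import CompletionResponse

# Assumed response shape: callers read the JSON payload from .text,
# which is why the test sets mock_run_llm_query.return_value.text.
response = CompletionResponse(text=json.dumps({"title": "Test Title"}))
assert json.loads(response.text)["title"] == "Test Title"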
@@ -43,58 +48,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
    mock_run_llm_query.side_effect = Exception("LLM query failed")

    result = get_ai_document_classification(mock_document)

    assert result == {}


def test_parse_llm_classification_response_valid():
    mock_response = json.dumps(
        {
            "title": "Test Title",
            "tags": ["test", "document"],
            "correspondents": ["John Doe"],
            "document_types": ["report"],
            "storage_paths": ["Reports"],
            "dates": ["2023-01-01"],
        },
    )

    result = parse_ai_classification_response(mock_response)

    assert result["title"] == "Test Title"
    assert result["tags"] == ["test", "document"]
    assert result["correspondents"] == ["John Doe"]
    assert result["document_types"] == ["report"]
    assert result["storage_paths"] == ["Reports"]
    assert result["dates"] == ["2023-01-01"]
    # assert raises an exception
    with pytest.raises(Exception):
        get_ai_document_classification(mock_document)


def test_parse_llm_classification_response_invalid_json():
    mock_response = "Invalid JSON"
    mock_response = MagicMock()
    mock_response.text = "Invalid JSON response"

    result = parse_ai_classification_response(mock_response)
    result = parse_ai_response(mock_response)

    assert result == {}


def test_parse_llm_classification_response_partial_data():
    mock_response = json.dumps(
        {
            "title": "Partial Data",
            "tags": ["partial"],
            "correspondents": "Jane Doe",
            "document_types": "note",
            "storage_paths": [],
            "dates": [],
        },
    )
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@patch("paperless.ai.ai_classifier.build_prompt_with_rag")
@override_settings(
    LLM_EMBEDDING_BACKEND="local",
    LLM_EMBEDDING_MODEL="some_model",
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_use_rag_if_configured(
    mock_build_prompt_with_rag,
    mock_run_llm_query,
    mock_document,
):
    mock_build_prompt_with_rag.return_value = "Prompt with RAG"
    mock_run_llm_query.return_value.text = json.dumps({})
    get_ai_document_classification(mock_document)
    mock_build_prompt_with_rag.assert_called_once()

    result = parse_ai_classification_response(mock_response)

    assert result["title"] == "Partial Data"
    assert result["tags"] == ["partial"]
    assert result["correspondents"] == ["Jane Doe"]
    assert result["document_types"] == ["note"]
    assert result["storage_paths"] == []
    assert result["dates"] == []
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@patch("paperless.ai.ai_classifier.build_prompt_without_rag")
@patch("paperless.config.AIConfig")
@override_settings(
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_use_without_rag_if_not_configured(
    mock_ai_config,
    mock_build_prompt_without_rag,
    mock_run_llm_query,
    mock_document,
):
    mock_ai_config.llm_embedding_backend = None
    mock_build_prompt_without_rag.return_value = "Prompt without RAG"
    mock_run_llm_query.return_value.text = json.dumps({})
    get_ai_document_classification(mock_document)
    mock_build_prompt_without_rag.assert_called_once()
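
The two tests above pin down when RAG is used. A sketch of the assumed branching inside get_ai_document_classification, with names taken from the patch targets; constructing AIConfig() without arguments is an assumption and the real implementation may differ:

from paperless.ai.ai_classifier import build_prompt_with_rag
from paperless.ai.ai_classifier import build_prompt_without_rag
from paperless.config import AIConfig


def build_prompt(document):
    # Assumed logic: use the RAG prompt only when an embedding backend
    # is configured, otherwise fall back to the plain prompt.
    config = AIConfig()
    if config.llm_embedding_backend:
        return build_prompt_with_rag(document)
    return build_prompt_without_rag(document)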
@@ -1,95 +1,93 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from django.conf import settings
from llama_index.core.llms import ChatMessage

from paperless.ai.client import AIClient


@pytest.fixture
def mock_settings():
    settings.LLM_BACKEND = "openai"
    settings.LLM_MODEL = "gpt-3.5-turbo"
    settings.LLM_API_KEY = "test-api-key"
    yield settings
def mock_ai_config():
    with patch("paperless.ai.client.AIConfig") as MockAIConfig:
        mock_config = MagicMock()
        MockAIConfig.return_value = mock_config
        yield mock_config


@pytest.mark.django_db
@patch("paperless.ai.client.AIClient._run_openai_query")
@patch("paperless.ai.client.AIClient._run_ollama_query")
def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
    mock_settings.LLM_BACKEND = "openai"
    mock_openai_query.return_value = "OpenAI response"
@pytest.fixture
def mock_ollama_llm():
    with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
        yield MockOllamaLLM


@pytest.fixture
def mock_openai_llm():
    with patch("paperless.ai.client.OpenAI") as MockOpenAI:
        yield MockOpenAI


def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    client = AIClient()
    result = client.run_llm_query("Test prompt")
    assert result == "OpenAI response"
    mock_openai_query.assert_called_once_with("Test prompt")
    mock_ollama_query.assert_not_called()

    mock_ollama_llm.assert_called_once_with(
        model="test_model",
        base_url="http://test-url",
    )
    assert client.llm == mock_ollama_llm.return_value


@pytest.mark.django_db
@patch("paperless.ai.client.AIClient._run_openai_query")
@patch("paperless.ai.client.AIClient._run_ollama_query")
def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
    mock_settings.LLM_BACKEND = "ollama"
    mock_ollama_query.return_value = "Ollama response"
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
    mock_ai_config.llm_backend = "openai"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.openai_api_key = "test_api_key"

    client = AIClient()
    result = client.run_llm_query("Test prompt")
    assert result == "Ollama response"
    mock_ollama_query.assert_called_once_with("Test prompt")
    mock_openai_query.assert_not_called()

    mock_openai_llm.assert_called_once_with(
        model="test_model",
        api_key="test_api_key",
    )
    assert client.llm == mock_openai_llm.return_value


@pytest.mark.django_db
def test_run_llm_query_unsupported_backend(mock_settings):
    mock_settings.LLM_BACKEND = "unsupported"
    client = AIClient()
def test_get_llm_unsupported_backend(mock_ai_config):
    mock_ai_config.llm_backend = "unsupported"

    with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
        client.run_llm_query("Test prompt")
        AIClient()


@pytest.mark.django_db
def test_run_openai_query(httpx_mock, mock_settings):
    mock_settings.LLM_BACKEND = "openai"
    httpx_mock.add_response(
        url="https://api.openai.com/v1/chat/completions",
        json={
            "choices": [{"message": {"content": "OpenAI response"}}],
        },
    )
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    mock_llm_instance = mock_ollama_llm.return_value
    mock_llm_instance.complete.return_value = "test_result"

    client = AIClient()
    result = client.run_llm_query("Test prompt")
    assert result == "OpenAI response"
    result = client.run_llm_query("test_prompt")

    request = httpx_mock.get_request()
    assert request.method == "POST"
    assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
    assert request.headers["Content-Type"] == "application/json"
    assert json.loads(request.content) == {
        "model": mock_settings.LLM_MODEL,
        "messages": [{"role": "user", "content": "Test prompt"}],
        "temperature": 0.3,
    }
    mock_llm_instance.complete.assert_called_once_with("test_prompt")
    assert result == "test_result"


@pytest.mark.django_db
def test_run_ollama_query(httpx_mock, mock_settings):
    mock_settings.LLM_BACKEND = "ollama"
    httpx_mock.add_response(
        url="http://localhost:11434/api/chat",
        json={"message": {"content": "Ollama response"}},
    )
def test_run_chat(mock_ai_config, mock_ollama_llm):
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    mock_llm_instance = mock_ollama_llm.return_value
    mock_llm_instance.chat.return_value = "test_chat_result"

    client = AIClient()
    result = client.run_llm_query("Test prompt")
    assert result == "Ollama response"
    messages = [ChatMessage(role="user", content="Hello")]
    result = client.run_chat(messages)

    request = httpx_mock.get_request()
    assert request.method == "POST"
    assert json.loads(request.content) == {
        "model": mock_settings.LLM_MODEL,
        "messages": [{"role": "user", "content": "Test prompt"}],
        "stream": False,
    }
    mock_llm_instance.chat.assert_called_once_with(messages)
    assert result == "test_chat_result"
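
For reference, a sketch of the AIClient behaviour the rewritten tests encode, assumed from the assertions rather than taken from the actual implementation: the constructor builds a llama_index LLM from AIConfig and rejects unknown backends, and run_llm_query / run_chat delegate to it.

from paperless.ai.client import OllamaLLM  # names resolved as in the patch targets above
from paperless.ai.client import OpenAI
from paperless.config import AIConfig


class AIClientSketch:
    """Illustrative only; mirrors what the tests assert about AIClient."""

    def __init__(self):
        config = AIConfig()
        if config.llm_backend == "ollama":
            self.llm = OllamaLLM(model=config.llm_model, base_url=config.llm_url)
        elif config.llm_backend == "openai":
            self.llm = OpenAI(model=config.llm_model, api_key=config.openai_api_key)
        else:
            raise ValueError(f"Unsupported LLM backend: {config.llm_backend}")

    def run_llm_query(self, prompt):
        return self.llm.complete(prompt)

    def run_chat(self, messages):
        return self.llm.chat(messages)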