# Mirror of https://github.com/paperless-ngx/paperless-ngx.git
# Synced 2025-05-24 02:02:23 -04:00
# 118 lines, 3.6 KiB, Python
import json
|
|
from unittest.mock import MagicMock
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from django.test import override_settings
|
|
|
|
from documents.models import Document
|
|
from paperless.ai.ai_classifier import build_prompt_with_rag
|
|
from paperless.ai.ai_classifier import build_prompt_without_rag
|
|
from paperless.ai.ai_classifier import get_ai_document_classification
|
|
from paperless.ai.ai_classifier import parse_ai_response
|
|
|
|
|
|
@pytest.fixture
def mock_document():
    """Provide a ``Document`` with a fixed filename and text content.

    The instance is only constructed here; this fixture does not save it.
    """
    document = Document(
        filename="test.pdf",
        content="This is a test document content.",
    )
    return document
|
|
|
|
|
|
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@override_settings(
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
    """Check that the mocked LLM's JSON reply is returned field by field."""
    expected = {
        "title": "Test Title",
        "tags": ["test", "document"],
        "correspondents": ["John Doe"],
        "document_types": ["report"],
        "storage_paths": ["Reports"],
        "dates": ["2023-01-01"],
    }
    # The mocked query result carries the serialized payload on ``.text``.
    llm_response = mock_run_llm_query.return_value
    llm_response.text = json.dumps(expected)

    result = get_ai_document_classification(mock_document)

    # Every field sent by the mocked LLM must come back unchanged.
    for field, value in expected.items():
        assert result[field] == value
|
|
|
|
|
|
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
    """Check that classification raises when the mocked LLM query fails.

    NOTE(review): ``pytest.raises(Exception)`` is satisfied by any exception,
    so this does not prove the error originated from the mocked query.
    """
    llm_error = Exception("LLM query failed")
    mock_run_llm_query.side_effect = llm_error

    # The call is expected to raise instead of returning a result.
    with pytest.raises(Exception):
        get_ai_document_classification(mock_document)
|
|
|
|
|
|
def test_parse_llm_classification_response_invalid_json():
    """Check that a response whose text is not JSON parses to an empty dict."""
    # Keyword arguments to MagicMock are set as attributes on the new mock.
    response = MagicMock(text="Invalid JSON response")

    assert parse_ai_response(response) == {}
|
|
|
|
|
|
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@patch("paperless.ai.ai_classifier.build_prompt_with_rag")
@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
    LLM_EMBEDDING_MODEL="some_model",
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_use_rag_if_configured(
    mock_build_prompt_with_rag,
    mock_run_llm_query,
    mock_document,
):
    """Check that the RAG prompt builder is called once with embedding settings."""
    # An empty JSON object is enough: only the prompt-builder call is checked.
    empty_reply = json.dumps({})
    mock_run_llm_query.return_value.text = empty_reply
    mock_build_prompt_with_rag.return_value = "Prompt with RAG"

    get_ai_document_classification(mock_document)

    mock_build_prompt_with_rag.assert_called_once()
|
|
|
|
|
|
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@patch("paperless.ai.ai_classifier.build_prompt_without_rag")
@patch("paperless.config.AIConfig")
@override_settings(
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_use_without_rag_if_not_configured(
    mock_ai_config,
    mock_build_prompt_without_rag,
    mock_run_llm_query,
    mock_document,
):
    """Check that the non-RAG prompt builder is called once without embeddings.

    NOTE(review): the ``paperless.config.AIConfig`` patch only affects code
    that looks the class up through that module at call time; confirm the
    classifier resolves ``AIConfig`` there.
    """
    # ``mock_ai_config`` stands in for the AIConfig *class*. Code that
    # instantiates it receives ``mock_ai_config.return_value``, so the
    # instance must be configured too; otherwise its
    # ``llm_embedding_backend`` is an auto-created, truthy MagicMock.
    mock_ai_config.llm_embedding_backend = None
    mock_ai_config.return_value.llm_embedding_backend = None

    mock_build_prompt_without_rag.return_value = "Prompt without RAG"
    # An empty JSON object is enough: only the prompt-builder call is checked.
    mock_run_llm_query.return_value.text = json.dumps({})

    get_ai_document_classification(mock_document)

    mock_build_prompt_without_rag.assert_called_once()
|
|
|
|
|
|
@override_settings(
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_prompt_with_without_rag(mock_document):
    """Check that only the RAG prompt has the similar-documents header."""
    context_header = "CONTEXT FROM SIMILAR DOCUMENTS:"

    # Plain prompt first: the header must be absent.
    plain_prompt = build_prompt_without_rag(mock_document)
    assert context_header not in plain_prompt

    # RAG prompt second: the header must be present.
    rag_prompt = build_prompt_with_rag(mock_document)
    assert context_header in rag_prompt
|