Move ai to its own module

This commit is contained in:
shamoon 2025-04-28 22:25:02 -07:00
parent 62588e9819
commit c5f618d822
No known key found for this signature in database
16 changed files with 80 additions and 80 deletions

View File

@ -52,10 +52,10 @@ from documents.sanity_checker import SanityCheckFailedException
from documents.signals import document_updated from documents.signals import document_updated
from documents.signals.handlers import cleanup_document_deletion from documents.signals.handlers import cleanup_document_deletion
from documents.signals.handlers import run_workflows from documents.signals.handlers import run_workflows
from paperless.ai.indexing import llm_index_add_or_update_document
from paperless.ai.indexing import llm_index_remove_document
from paperless.ai.indexing import update_llm_index
from paperless.config import AIConfig from paperless.config import AIConfig
from paperless_ai.indexing import llm_index_add_or_update_document
from paperless_ai.indexing import llm_index_remove_document
from paperless_ai.indexing import update_llm_index
if settings.AUDIT_LOG_ENABLED: if settings.AUDIT_LOG_ENABLED:
from auditlog.models import LogEntry from auditlog.models import LogEntry
@ -536,6 +536,6 @@ def remove_document_from_llm_index(document):
# TODO: schedule to run periodically # TODO: schedule to run periodically
@shared_task @shared_task
def rebuild_llm_index_task(): def rebuild_llm_index_task():
from paperless.ai.indexing import update_llm_index from paperless_ai.indexing import update_llm_index
update_llm_index(rebuild=True) update_llm_index(rebuild=True)

View File

@ -174,13 +174,6 @@ from documents.tasks import sanity_check
from documents.tasks import train_classifier from documents.tasks import train_classifier
from documents.templating.filepath import validate_filepath_template_and_render from documents.templating.filepath import validate_filepath_template_and_render
from paperless import version from paperless import version
from paperless.ai.ai_classifier import get_ai_document_classification
from paperless.ai.chat import stream_chat_with_documents
from paperless.ai.matching import extract_unmatched_names
from paperless.ai.matching import match_correspondents_by_name
from paperless.ai.matching import match_document_types_by_name
from paperless.ai.matching import match_storage_paths_by_name
from paperless.ai.matching import match_tags_by_name
from paperless.celery import app as celery_app from paperless.celery import app as celery_app
from paperless.config import AIConfig from paperless.config import AIConfig
from paperless.config import GeneralConfig from paperless.config import GeneralConfig
@ -188,6 +181,13 @@ from paperless.db import GnuPG
from paperless.serialisers import GroupSerializer from paperless.serialisers import GroupSerializer
from paperless.serialisers import UserSerializer from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination from paperless.views import StandardPagination
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.chat import stream_chat_with_documents
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name
from paperless_mail.models import MailAccount from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule from paperless_mail.models import MailRule
from paperless_mail.oauth import PaperlessMailOAuth2Manager from paperless_mail.oauth import PaperlessMailOAuth2Manager

View File

@ -6,11 +6,11 @@ from llama_index.core.base.llms.types import CompletionResponse
from documents.models import Document from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware from documents.permissions import get_objects_for_user_owner_aware
from paperless.ai.client import AIClient
from paperless.ai.indexing import query_similar_documents
from paperless.config import AIConfig from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.indexing import query_similar_documents
logger = logging.getLogger("paperless.ai.rag_classifier") logger = logging.getLogger("paperless_ai.rag_classifier")
def build_prompt_without_rag(document: Document) -> str: def build_prompt_without_rag(document: Document) -> str:

View File

@ -6,10 +6,10 @@ from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine from llama_index.core.query_engine import RetrieverQueryEngine
from documents.models import Document from documents.models import Document
from paperless.ai.client import AIClient from paperless_ai.client import AIClient
from paperless.ai.indexing import load_or_build_index from paperless_ai.indexing import load_or_build_index
logger = logging.getLogger("paperless.ai.chat") logger = logging.getLogger("paperless_ai.chat")
CHAT_PROMPT_TMPL = PromptTemplate( CHAT_PROMPT_TMPL = PromptTemplate(
template="""Context information is below. template="""Context information is below.

View File

@ -6,7 +6,7 @@ from llama_index.llms.openai import OpenAI
from paperless.config import AIConfig from paperless.config import AIConfig
logger = logging.getLogger("paperless.ai.client") logger = logging.getLogger("paperless_ai.client")
class AIClient: class AIClient:

View File

@ -17,11 +17,11 @@ from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore from llama_index.vector_stores.faiss import FaissVectorStore
from documents.models import Document from documents.models import Document
from paperless.ai.embedding import build_llm_index_text from paperless_ai.embedding import build_llm_index_text
from paperless.ai.embedding import get_embedding_dim from paperless_ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model from paperless_ai.embedding import get_embedding_model
logger = logging.getLogger("paperless.ai.indexing") logger = logging.getLogger("paperless_ai.indexing")
def get_or_create_storage_context(*, rebuild=False): def get_or_create_storage_context(*, rebuild=False):

View File

@ -12,7 +12,7 @@ from documents.permissions import get_objects_for_user_owner_aware
MATCH_THRESHOLD = 0.8 MATCH_THRESHOLD = 0.8
logger = logging.getLogger("paperless.ai.matching") logger = logging.getLogger("paperless_ai.matching")
def match_tags_by_name(names: list[str], user: User) -> list[Tag]: def match_tags_by_name(names: list[str], user: User) -> list[Tag]:

View File

View File

@ -6,11 +6,11 @@ import pytest
from django.test import override_settings from django.test import override_settings
from documents.models import Document from documents.models import Document
from paperless.ai.ai_classifier import build_prompt_with_rag from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless.ai.ai_classifier import build_prompt_without_rag from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless.ai.ai_classifier import get_ai_document_classification from paperless_ai.ai_classifier import get_ai_document_classification
from paperless.ai.ai_classifier import get_context_for_document from paperless_ai.ai_classifier import get_context_for_document
from paperless.ai.ai_classifier import parse_ai_response from paperless_ai.ai_classifier import parse_ai_response
@pytest.fixture @pytest.fixture
@ -49,7 +49,7 @@ def mock_document():
@pytest.mark.django_db @pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query") @patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings( @override_settings(
LLM_BACKEND="ollama", LLM_BACKEND="ollama",
LLM_MODEL="some_model", LLM_MODEL="some_model",
@ -77,7 +77,7 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
@pytest.mark.django_db @pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query") @patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
mock_run_llm_query.side_effect = Exception("LLM query failed") mock_run_llm_query.side_effect = Exception("LLM query failed")
@ -96,8 +96,8 @@ def test_parse_llm_classification_response_invalid_json():
@pytest.mark.django_db @pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query") @patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless.ai.ai_classifier.build_prompt_with_rag") @patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@override_settings( @override_settings(
LLM_EMBEDDING_BACKEND="huggingface", LLM_EMBEDDING_BACKEND="huggingface",
LLM_EMBEDDING_MODEL="some_model", LLM_EMBEDDING_MODEL="some_model",
@ -116,8 +116,8 @@ def test_use_rag_if_configured(
@pytest.mark.django_db @pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query") @patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless.ai.ai_classifier.build_prompt_without_rag") @patch("paperless_ai.ai_classifier.build_prompt_without_rag")
@patch("paperless.config.AIConfig") @patch("paperless.config.AIConfig")
@override_settings( @override_settings(
LLM_BACKEND="ollama", LLM_BACKEND="ollama",
@ -144,7 +144,7 @@ def test_use_without_rag_if_not_configured(
) )
def test_prompt_with_without_rag(mock_document): def test_prompt_with_without_rag(mock_document):
with patch( with patch(
"paperless.ai.ai_classifier.get_context_for_document", "paperless_ai.ai_classifier.get_context_for_document",
return_value="Context from similar documents", return_value="Context from similar documents",
): ):
prompt = build_prompt_without_rag(mock_document) prompt = build_prompt_without_rag(mock_document)
@ -174,7 +174,7 @@ def mock_similar_documents():
return [doc1, doc2, doc3] return [doc1, doc2, doc3]
@patch("paperless.ai.ai_classifier.query_similar_documents") @patch("paperless_ai.ai_classifier.query_similar_documents")
def test_get_context_for_document( def test_get_context_for_document(
mock_query_similar_documents, mock_query_similar_documents,
mock_document, mock_document,
@ -193,6 +193,6 @@ def test_get_context_for_document(
def test_get_context_for_document_no_similar_docs(mock_document): def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless.ai.ai_classifier.query_similar_documents", return_value=[]): with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document) result = get_context_for_document(mock_document)
assert result == "" assert result == ""

View File

@ -7,7 +7,7 @@ from django.utils import timezone
from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document from documents.models import Document
from paperless.ai import indexing from paperless_ai import indexing
@pytest.fixture @pytest.fixture
@ -29,7 +29,7 @@ def real_document(db):
@pytest.fixture @pytest.fixture
def mock_embed_model(): def mock_embed_model():
with patch("paperless.ai.indexing.get_embedding_model") as mock: with patch("paperless_ai.indexing.get_embedding_model") as mock:
mock.return_value = FakeEmbedding() mock.return_value = FakeEmbedding()
yield mock yield mock
@ -112,7 +112,7 @@ def test_update_llm_index_partial_update(
mock_all.return_value = mock_queryset mock_all.return_value = mock_queryset
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes." # assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
with patch("paperless.ai.indexing.logger") as mock_logger: with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=False) indexing.update_llm_index(rebuild=False)
mock_logger.info.assert_called_once_with( mock_logger.info.assert_called_once_with(
"Updating %d nodes in LLM index.", "Updating %d nodes in LLM index.",
@ -139,15 +139,15 @@ def test_load_or_build_index_builds_when_nodes_given(
real_document, real_document,
): ):
with patch( with patch(
"paperless.ai.indexing.load_index_from_storage", "paperless_ai.indexing.load_index_from_storage",
side_effect=ValueError("Index not found"), side_effect=ValueError("Index not found"),
): ):
with patch( with patch(
"paperless.ai.indexing.VectorStoreIndex", "paperless_ai.indexing.VectorStoreIndex",
return_value=MagicMock(), return_value=MagicMock(),
) as mock_index_cls: ) as mock_index_cls:
with patch( with patch(
"paperless.ai.indexing.get_or_create_storage_context", "paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(), return_value=MagicMock(),
) as mock_storage: ) as mock_storage:
mock_storage.return_value.persist_dir = temp_llm_index_dir mock_storage.return_value.persist_dir = temp_llm_index_dir
@ -161,7 +161,7 @@ def test_load_or_build_index_raises_exception_when_no_nodes(
temp_llm_index_dir, temp_llm_index_dir,
): ):
with patch( with patch(
"paperless.ai.indexing.load_index_from_storage", "paperless_ai.indexing.load_index_from_storage",
side_effect=ValueError("Index not found"), side_effect=ValueError("Index not found"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@ -207,7 +207,7 @@ def test_update_llm_index_no_documents(
mock_all.return_value = mock_queryset mock_all.return_value = mock_queryset
# check log message # check log message
with patch("paperless.ai.indexing.logger") as mock_logger: with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=True) indexing.update_llm_index(rebuild=True)
mock_logger.warning.assert_called_once_with( mock_logger.warning.assert_called_once_with(
"No documents found to index.", "No documents found to index.",
@ -223,10 +223,10 @@ def test_query_similar_documents(
real_document, real_document,
): ):
with ( with (
patch("paperless.ai.indexing.get_or_create_storage_context") as mock_storage, patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index, patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls, patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter, patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
): ):
mock_storage.return_value = MagicMock() mock_storage.return_value = MagicMock()
mock_storage.return_value.persist_dir = temp_llm_index_dir mock_storage.return_value.persist_dir = temp_llm_index_dir
@ -251,7 +251,7 @@ def test_query_similar_documents(
result = indexing.query_similar_documents(real_document, top_k=3) result = indexing.query_similar_documents(real_document, top_k=3)
mock_load_or_build_index.assert_called_once() mock_load_or_build_index.assert_called_once()
mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3) mock_retriever_cls.assert_called_once()
mock_retriever.retrieve.assert_called_once_with( mock_retriever.retrieve.assert_called_once_with(
"Test Document\nThis is some test content.", "Test Document\nThis is some test content.",
) )

View File

@ -5,7 +5,7 @@ import pytest
from llama_index.core import VectorStoreIndex from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode from llama_index.core.schema import TextNode
from paperless.ai.chat import stream_chat_with_documents from paperless_ai.chat import stream_chat_with_documents
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -44,10 +44,10 @@ def mock_document():
def test_stream_chat_with_one_document_full_content(mock_document): def test_stream_chat_with_one_document_full_content(mock_document):
with ( with (
patch("paperless.ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_or_build_index") as mock_load_index, patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch( patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args", "paperless_ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls, ) as mock_query_engine_cls,
): ):
mock_client = MagicMock() mock_client = MagicMock()
@ -75,10 +75,10 @@ def test_stream_chat_with_one_document_full_content(mock_document):
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes): def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
with ( with (
patch("paperless.ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_or_build_index") as mock_load_index, patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch( patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args", "paperless_ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls, ) as mock_query_engine_cls,
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever, patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
): ):
@ -125,8 +125,8 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
def test_stream_chat_no_matching_nodes(): def test_stream_chat_no_matching_nodes():
with ( with (
patch("paperless.ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_or_build_index") as mock_load_index, patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
): ):
mock_client = MagicMock() mock_client = MagicMock()
mock_client_cls.return_value = mock_client mock_client_cls.return_value = mock_client

View File

@ -4,12 +4,12 @@ from unittest.mock import patch
import pytest import pytest
from llama_index.core.llms import ChatMessage from llama_index.core.llms import ChatMessage
from paperless.ai.client import AIClient from paperless_ai.client import AIClient
@pytest.fixture @pytest.fixture
def mock_ai_config(): def mock_ai_config():
with patch("paperless.ai.client.AIConfig") as MockAIConfig: with patch("paperless_ai.client.AIConfig") as MockAIConfig:
mock_config = MagicMock() mock_config = MagicMock()
MockAIConfig.return_value = mock_config MockAIConfig.return_value = mock_config
yield mock_config yield mock_config
@ -17,13 +17,13 @@ def mock_ai_config():
@pytest.fixture @pytest.fixture
def mock_ollama_llm(): def mock_ollama_llm():
with patch("paperless.ai.client.Ollama") as MockOllama: with patch("paperless_ai.client.Ollama") as MockOllama:
yield MockOllama yield MockOllama
@pytest.fixture @pytest.fixture
def mock_openai_llm(): def mock_openai_llm():
with patch("paperless.ai.client.OpenAI") as MockOpenAI: with patch("paperless_ai.client.OpenAI") as MockOpenAI:
yield MockOpenAI yield MockOpenAI

View File

@ -4,15 +4,15 @@ from unittest.mock import patch
import pytest import pytest
from documents.models import Document from documents.models import Document
from paperless.ai.embedding import build_llm_index_text
from paperless.ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model
from paperless.models import LLMEmbeddingBackend from paperless.models import LLMEmbeddingBackend
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model
@pytest.fixture @pytest.fixture
def mock_ai_config(): def mock_ai_config():
with patch("paperless.ai.embedding.AIConfig") as MockAIConfig: with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
yield MockAIConfig yield MockAIConfig
@ -56,7 +56,7 @@ def test_get_embedding_model_openai(mock_ai_config):
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small" mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
mock_ai_config.return_value.llm_api_key = "test_api_key" mock_ai_config.return_value.llm_api_key = "test_api_key"
with patch("paperless.ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding: with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
model = get_embedding_model() model = get_embedding_model()
MockOpenAIEmbedding.assert_called_once_with( MockOpenAIEmbedding.assert_called_once_with(
model="text-embedding-3-small", model="text-embedding-3-small",
@ -72,7 +72,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
) )
with patch( with patch(
"paperless.ai.embedding.HuggingFaceEmbedding", "paperless_ai.embedding.HuggingFaceEmbedding",
) as MockHuggingFaceEmbedding: ) as MockHuggingFaceEmbedding:
model = get_embedding_model() model = get_embedding_model()
MockHuggingFaceEmbedding.assert_called_once_with( MockHuggingFaceEmbedding.assert_called_once_with(

View File

@ -6,11 +6,11 @@ from documents.models import Correspondent
from documents.models import DocumentType from documents.models import DocumentType
from documents.models import StoragePath from documents.models import StoragePath
from documents.models import Tag from documents.models import Tag
from paperless.ai.matching import extract_unmatched_names from paperless_ai.matching import extract_unmatched_names
from paperless.ai.matching import match_correspondents_by_name from paperless_ai.matching import match_correspondents_by_name
from paperless.ai.matching import match_document_types_by_name from paperless_ai.matching import match_document_types_by_name
from paperless.ai.matching import match_storage_paths_by_name from paperless_ai.matching import match_storage_paths_by_name
from paperless.ai.matching import match_tags_by_name from paperless_ai.matching import match_tags_by_name
class TestAIMatching(TestCase): class TestAIMatching(TestCase):
@ -31,7 +31,7 @@ class TestAIMatching(TestCase):
self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1") self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2") self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")
@patch("paperless.ai.matching.get_objects_for_user_owner_aware") @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_tags_by_name(self, mock_get_objects): def test_match_tags_by_name(self, mock_get_objects):
mock_get_objects.return_value = Tag.objects.all() mock_get_objects.return_value = Tag.objects.all()
names = ["Test Tag 1", "Nonexistent Tag"] names = ["Test Tag 1", "Nonexistent Tag"]
@ -39,7 +39,7 @@ class TestAIMatching(TestCase):
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Tag 1") self.assertEqual(result[0].name, "Test Tag 1")
@patch("paperless.ai.matching.get_objects_for_user_owner_aware") @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_correspondents_by_name(self, mock_get_objects): def test_match_correspondents_by_name(self, mock_get_objects):
mock_get_objects.return_value = Correspondent.objects.all() mock_get_objects.return_value = Correspondent.objects.all()
names = ["Test Correspondent 1", "Nonexistent Correspondent"] names = ["Test Correspondent 1", "Nonexistent Correspondent"]
@ -47,7 +47,7 @@ class TestAIMatching(TestCase):
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Correspondent 1") self.assertEqual(result[0].name, "Test Correspondent 1")
@patch("paperless.ai.matching.get_objects_for_user_owner_aware") @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_document_types_by_name(self, mock_get_objects): def test_match_document_types_by_name(self, mock_get_objects):
mock_get_objects.return_value = DocumentType.objects.all() mock_get_objects.return_value = DocumentType.objects.all()
names = ["Test Document Type 1", "Nonexistent Document Type"] names = ["Test Document Type 1", "Nonexistent Document Type"]
@ -55,7 +55,7 @@ class TestAIMatching(TestCase):
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Document Type 1") self.assertEqual(result[0].name, "Test Document Type 1")
@patch("paperless.ai.matching.get_objects_for_user_owner_aware") @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_storage_paths_by_name(self, mock_get_objects): def test_match_storage_paths_by_name(self, mock_get_objects):
mock_get_objects.return_value = StoragePath.objects.all() mock_get_objects.return_value = StoragePath.objects.all()
names = ["Test Storage Path 1", "Nonexistent Storage Path"] names = ["Test Storage Path 1", "Nonexistent Storage Path"]
@ -69,14 +69,14 @@ class TestAIMatching(TestCase):
unmatched_names = extract_unmatched_names(llm_names, matched_objects) unmatched_names = extract_unmatched_names(llm_names, matched_objects)
self.assertEqual(unmatched_names, ["Nonexistent Tag"]) self.assertEqual(unmatched_names, ["Nonexistent Tag"])
@patch("paperless.ai.matching.get_objects_for_user_owner_aware") @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_tags_by_name_with_empty_names(self, mock_get_objects): def test_match_tags_by_name_with_empty_names(self, mock_get_objects):
mock_get_objects.return_value = Tag.objects.all() mock_get_objects.return_value = Tag.objects.all()
names = [None, "", " "] names = [None, "", " "]
result = match_tags_by_name(names, user=None) result = match_tags_by_name(names, user=None)
self.assertEqual(result, []) self.assertEqual(result, [])
@patch("paperless.ai.matching.get_objects_for_user_owner_aware") @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_tags_with_fuzzy_matching(self, mock_get_objects): def test_match_tags_with_fuzzy_matching(self, mock_get_objects):
mock_get_objects.return_value = Tag.objects.all() mock_get_objects.return_value = Tag.objects.all()
names = ["Test Taag 1", "Teest Tag 2"] names = ["Test Taag 1", "Teest Tag 2"]