mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-24 02:02:23 -04:00
Move ai to its own module
This commit is contained in:
parent
62588e9819
commit
c5f618d822
@ -52,10 +52,10 @@ from documents.sanity_checker import SanityCheckFailedException
|
|||||||
from documents.signals import document_updated
|
from documents.signals import document_updated
|
||||||
from documents.signals.handlers import cleanup_document_deletion
|
from documents.signals.handlers import cleanup_document_deletion
|
||||||
from documents.signals.handlers import run_workflows
|
from documents.signals.handlers import run_workflows
|
||||||
from paperless.ai.indexing import llm_index_add_or_update_document
|
|
||||||
from paperless.ai.indexing import llm_index_remove_document
|
|
||||||
from paperless.ai.indexing import update_llm_index
|
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
|
from paperless_ai.indexing import llm_index_add_or_update_document
|
||||||
|
from paperless_ai.indexing import llm_index_remove_document
|
||||||
|
from paperless_ai.indexing import update_llm_index
|
||||||
|
|
||||||
if settings.AUDIT_LOG_ENABLED:
|
if settings.AUDIT_LOG_ENABLED:
|
||||||
from auditlog.models import LogEntry
|
from auditlog.models import LogEntry
|
||||||
@ -536,6 +536,6 @@ def remove_document_from_llm_index(document):
|
|||||||
# TODO: schedule to run periodically
|
# TODO: schedule to run periodically
|
||||||
@shared_task
|
@shared_task
|
||||||
def rebuild_llm_index_task():
|
def rebuild_llm_index_task():
|
||||||
from paperless.ai.indexing import update_llm_index
|
from paperless_ai.indexing import update_llm_index
|
||||||
|
|
||||||
update_llm_index(rebuild=True)
|
update_llm_index(rebuild=True)
|
||||||
|
@ -174,13 +174,6 @@ from documents.tasks import sanity_check
|
|||||||
from documents.tasks import train_classifier
|
from documents.tasks import train_classifier
|
||||||
from documents.templating.filepath import validate_filepath_template_and_render
|
from documents.templating.filepath import validate_filepath_template_and_render
|
||||||
from paperless import version
|
from paperless import version
|
||||||
from paperless.ai.ai_classifier import get_ai_document_classification
|
|
||||||
from paperless.ai.chat import stream_chat_with_documents
|
|
||||||
from paperless.ai.matching import extract_unmatched_names
|
|
||||||
from paperless.ai.matching import match_correspondents_by_name
|
|
||||||
from paperless.ai.matching import match_document_types_by_name
|
|
||||||
from paperless.ai.matching import match_storage_paths_by_name
|
|
||||||
from paperless.ai.matching import match_tags_by_name
|
|
||||||
from paperless.celery import app as celery_app
|
from paperless.celery import app as celery_app
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
from paperless.config import GeneralConfig
|
from paperless.config import GeneralConfig
|
||||||
@ -188,6 +181,13 @@ from paperless.db import GnuPG
|
|||||||
from paperless.serialisers import GroupSerializer
|
from paperless.serialisers import GroupSerializer
|
||||||
from paperless.serialisers import UserSerializer
|
from paperless.serialisers import UserSerializer
|
||||||
from paperless.views import StandardPagination
|
from paperless.views import StandardPagination
|
||||||
|
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||||
|
from paperless_ai.chat import stream_chat_with_documents
|
||||||
|
from paperless_ai.matching import extract_unmatched_names
|
||||||
|
from paperless_ai.matching import match_correspondents_by_name
|
||||||
|
from paperless_ai.matching import match_document_types_by_name
|
||||||
|
from paperless_ai.matching import match_storage_paths_by_name
|
||||||
|
from paperless_ai.matching import match_tags_by_name
|
||||||
from paperless_mail.models import MailAccount
|
from paperless_mail.models import MailAccount
|
||||||
from paperless_mail.models import MailRule
|
from paperless_mail.models import MailRule
|
||||||
from paperless_mail.oauth import PaperlessMailOAuth2Manager
|
from paperless_mail.oauth import PaperlessMailOAuth2Manager
|
||||||
|
@ -6,11 +6,11 @@ from llama_index.core.base.llms.types import CompletionResponse
|
|||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.permissions import get_objects_for_user_owner_aware
|
from documents.permissions import get_objects_for_user_owner_aware
|
||||||
from paperless.ai.client import AIClient
|
|
||||||
from paperless.ai.indexing import query_similar_documents
|
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
|
from paperless_ai.client import AIClient
|
||||||
|
from paperless_ai.indexing import query_similar_documents
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.ai.rag_classifier")
|
logger = logging.getLogger("paperless_ai.rag_classifier")
|
||||||
|
|
||||||
|
|
||||||
def build_prompt_without_rag(document: Document) -> str:
|
def build_prompt_without_rag(document: Document) -> str:
|
@ -6,10 +6,10 @@ from llama_index.core.prompts import PromptTemplate
|
|||||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless.ai.client import AIClient
|
from paperless_ai.client import AIClient
|
||||||
from paperless.ai.indexing import load_or_build_index
|
from paperless_ai.indexing import load_or_build_index
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.ai.chat")
|
logger = logging.getLogger("paperless_ai.chat")
|
||||||
|
|
||||||
CHAT_PROMPT_TMPL = PromptTemplate(
|
CHAT_PROMPT_TMPL = PromptTemplate(
|
||||||
template="""Context information is below.
|
template="""Context information is below.
|
@ -6,7 +6,7 @@ from llama_index.llms.openai import OpenAI
|
|||||||
|
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.ai.client")
|
logger = logging.getLogger("paperless_ai.client")
|
||||||
|
|
||||||
|
|
||||||
class AIClient:
|
class AIClient:
|
@ -17,11 +17,11 @@ from llama_index.core.storage.index_store import SimpleIndexStore
|
|||||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless.ai.embedding import build_llm_index_text
|
from paperless_ai.embedding import build_llm_index_text
|
||||||
from paperless.ai.embedding import get_embedding_dim
|
from paperless_ai.embedding import get_embedding_dim
|
||||||
from paperless.ai.embedding import get_embedding_model
|
from paperless_ai.embedding import get_embedding_model
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.ai.indexing")
|
logger = logging.getLogger("paperless_ai.indexing")
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_storage_context(*, rebuild=False):
|
def get_or_create_storage_context(*, rebuild=False):
|
@ -12,7 +12,7 @@ from documents.permissions import get_objects_for_user_owner_aware
|
|||||||
|
|
||||||
MATCH_THRESHOLD = 0.8
|
MATCH_THRESHOLD = 0.8
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.ai.matching")
|
logger = logging.getLogger("paperless_ai.matching")
|
||||||
|
|
||||||
|
|
||||||
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
|
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
|
0
src/paperless_ai/tests/__init__.py
Normal file
0
src/paperless_ai/tests/__init__.py
Normal file
@ -6,11 +6,11 @@ import pytest
|
|||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless.ai.ai_classifier import build_prompt_with_rag
|
from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||||
from paperless.ai.ai_classifier import build_prompt_without_rag
|
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||||
from paperless.ai.ai_classifier import get_ai_document_classification
|
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||||
from paperless.ai.ai_classifier import get_context_for_document
|
from paperless_ai.ai_classifier import get_context_for_document
|
||||||
from paperless.ai.ai_classifier import parse_ai_response
|
from paperless_ai.ai_classifier import parse_ai_response
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -49,7 +49,7 @@ def mock_document():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@patch("paperless.ai.client.AIClient.run_llm_query")
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
@override_settings(
|
@override_settings(
|
||||||
LLM_BACKEND="ollama",
|
LLM_BACKEND="ollama",
|
||||||
LLM_MODEL="some_model",
|
LLM_MODEL="some_model",
|
||||||
@ -77,7 +77,7 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@patch("paperless.ai.client.AIClient.run_llm_query")
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
||||||
mock_run_llm_query.side_effect = Exception("LLM query failed")
|
mock_run_llm_query.side_effect = Exception("LLM query failed")
|
||||||
|
|
||||||
@ -96,8 +96,8 @@ def test_parse_llm_classification_response_invalid_json():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@patch("paperless.ai.client.AIClient.run_llm_query")
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
@patch("paperless.ai.ai_classifier.build_prompt_with_rag")
|
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
|
||||||
@override_settings(
|
@override_settings(
|
||||||
LLM_EMBEDDING_BACKEND="huggingface",
|
LLM_EMBEDDING_BACKEND="huggingface",
|
||||||
LLM_EMBEDDING_MODEL="some_model",
|
LLM_EMBEDDING_MODEL="some_model",
|
||||||
@ -116,8 +116,8 @@ def test_use_rag_if_configured(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@patch("paperless.ai.client.AIClient.run_llm_query")
|
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||||
@patch("paperless.ai.ai_classifier.build_prompt_without_rag")
|
@patch("paperless_ai.ai_classifier.build_prompt_without_rag")
|
||||||
@patch("paperless.config.AIConfig")
|
@patch("paperless.config.AIConfig")
|
||||||
@override_settings(
|
@override_settings(
|
||||||
LLM_BACKEND="ollama",
|
LLM_BACKEND="ollama",
|
||||||
@ -144,7 +144,7 @@ def test_use_without_rag_if_not_configured(
|
|||||||
)
|
)
|
||||||
def test_prompt_with_without_rag(mock_document):
|
def test_prompt_with_without_rag(mock_document):
|
||||||
with patch(
|
with patch(
|
||||||
"paperless.ai.ai_classifier.get_context_for_document",
|
"paperless_ai.ai_classifier.get_context_for_document",
|
||||||
return_value="Context from similar documents",
|
return_value="Context from similar documents",
|
||||||
):
|
):
|
||||||
prompt = build_prompt_without_rag(mock_document)
|
prompt = build_prompt_without_rag(mock_document)
|
||||||
@ -174,7 +174,7 @@ def mock_similar_documents():
|
|||||||
return [doc1, doc2, doc3]
|
return [doc1, doc2, doc3]
|
||||||
|
|
||||||
|
|
||||||
@patch("paperless.ai.ai_classifier.query_similar_documents")
|
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
||||||
def test_get_context_for_document(
|
def test_get_context_for_document(
|
||||||
mock_query_similar_documents,
|
mock_query_similar_documents,
|
||||||
mock_document,
|
mock_document,
|
||||||
@ -193,6 +193,6 @@ def test_get_context_for_document(
|
|||||||
|
|
||||||
|
|
||||||
def test_get_context_for_document_no_similar_docs(mock_document):
|
def test_get_context_for_document_no_similar_docs(mock_document):
|
||||||
with patch("paperless.ai.ai_classifier.query_similar_documents", return_value=[]):
|
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
|
||||||
result = get_context_for_document(mock_document)
|
result = get_context_for_document(mock_document)
|
||||||
assert result == ""
|
assert result == ""
|
@ -7,7 +7,7 @@ from django.utils import timezone
|
|||||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless.ai import indexing
|
from paperless_ai import indexing
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -29,7 +29,7 @@ def real_document(db):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_embed_model():
|
def mock_embed_model():
|
||||||
with patch("paperless.ai.indexing.get_embedding_model") as mock:
|
with patch("paperless_ai.indexing.get_embedding_model") as mock:
|
||||||
mock.return_value = FakeEmbedding()
|
mock.return_value = FakeEmbedding()
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
@ -112,7 +112,7 @@ def test_update_llm_index_partial_update(
|
|||||||
mock_all.return_value = mock_queryset
|
mock_all.return_value = mock_queryset
|
||||||
|
|
||||||
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
|
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
|
||||||
with patch("paperless.ai.indexing.logger") as mock_logger:
|
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||||
indexing.update_llm_index(rebuild=False)
|
indexing.update_llm_index(rebuild=False)
|
||||||
mock_logger.info.assert_called_once_with(
|
mock_logger.info.assert_called_once_with(
|
||||||
"Updating %d nodes in LLM index.",
|
"Updating %d nodes in LLM index.",
|
||||||
@ -139,15 +139,15 @@ def test_load_or_build_index_builds_when_nodes_given(
|
|||||||
real_document,
|
real_document,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
"paperless.ai.indexing.load_index_from_storage",
|
"paperless_ai.indexing.load_index_from_storage",
|
||||||
side_effect=ValueError("Index not found"),
|
side_effect=ValueError("Index not found"),
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
"paperless.ai.indexing.VectorStoreIndex",
|
"paperless_ai.indexing.VectorStoreIndex",
|
||||||
return_value=MagicMock(),
|
return_value=MagicMock(),
|
||||||
) as mock_index_cls:
|
) as mock_index_cls:
|
||||||
with patch(
|
with patch(
|
||||||
"paperless.ai.indexing.get_or_create_storage_context",
|
"paperless_ai.indexing.get_or_create_storage_context",
|
||||||
return_value=MagicMock(),
|
return_value=MagicMock(),
|
||||||
) as mock_storage:
|
) as mock_storage:
|
||||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||||
@ -161,7 +161,7 @@ def test_load_or_build_index_raises_exception_when_no_nodes(
|
|||||||
temp_llm_index_dir,
|
temp_llm_index_dir,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
"paperless.ai.indexing.load_index_from_storage",
|
"paperless_ai.indexing.load_index_from_storage",
|
||||||
side_effect=ValueError("Index not found"),
|
side_effect=ValueError("Index not found"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@ -207,7 +207,7 @@ def test_update_llm_index_no_documents(
|
|||||||
mock_all.return_value = mock_queryset
|
mock_all.return_value = mock_queryset
|
||||||
|
|
||||||
# check log message
|
# check log message
|
||||||
with patch("paperless.ai.indexing.logger") as mock_logger:
|
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||||
indexing.update_llm_index(rebuild=True)
|
indexing.update_llm_index(rebuild=True)
|
||||||
mock_logger.warning.assert_called_once_with(
|
mock_logger.warning.assert_called_once_with(
|
||||||
"No documents found to index.",
|
"No documents found to index.",
|
||||||
@ -223,10 +223,10 @@ def test_query_similar_documents(
|
|||||||
real_document,
|
real_document,
|
||||||
):
|
):
|
||||||
with (
|
with (
|
||||||
patch("paperless.ai.indexing.get_or_create_storage_context") as mock_storage,
|
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
|
||||||
patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||||
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
|
patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
|
||||||
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
|
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
||||||
):
|
):
|
||||||
mock_storage.return_value = MagicMock()
|
mock_storage.return_value = MagicMock()
|
||||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||||
@ -251,7 +251,7 @@ def test_query_similar_documents(
|
|||||||
result = indexing.query_similar_documents(real_document, top_k=3)
|
result = indexing.query_similar_documents(real_document, top_k=3)
|
||||||
|
|
||||||
mock_load_or_build_index.assert_called_once()
|
mock_load_or_build_index.assert_called_once()
|
||||||
mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
|
mock_retriever_cls.assert_called_once()
|
||||||
mock_retriever.retrieve.assert_called_once_with(
|
mock_retriever.retrieve.assert_called_once_with(
|
||||||
"Test Document\nThis is some test content.",
|
"Test Document\nThis is some test content.",
|
||||||
)
|
)
|
@ -5,7 +5,7 @@ import pytest
|
|||||||
from llama_index.core import VectorStoreIndex
|
from llama_index.core import VectorStoreIndex
|
||||||
from llama_index.core.schema import TextNode
|
from llama_index.core.schema import TextNode
|
||||||
|
|
||||||
from paperless.ai.chat import stream_chat_with_documents
|
from paperless_ai.chat import stream_chat_with_documents
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@ -44,10 +44,10 @@ def mock_document():
|
|||||||
|
|
||||||
def test_stream_chat_with_one_document_full_content(mock_document):
|
def test_stream_chat_with_one_document_full_content(mock_document):
|
||||||
with (
|
with (
|
||||||
patch("paperless.ai.chat.AIClient") as mock_client_cls,
|
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||||
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
|
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||||
patch(
|
patch(
|
||||||
"paperless.ai.chat.RetrieverQueryEngine.from_args",
|
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
||||||
) as mock_query_engine_cls,
|
) as mock_query_engine_cls,
|
||||||
):
|
):
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
@ -75,10 +75,10 @@ def test_stream_chat_with_one_document_full_content(mock_document):
|
|||||||
|
|
||||||
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
|
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
|
||||||
with (
|
with (
|
||||||
patch("paperless.ai.chat.AIClient") as mock_client_cls,
|
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||||
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
|
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||||
patch(
|
patch(
|
||||||
"paperless.ai.chat.RetrieverQueryEngine.from_args",
|
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
||||||
) as mock_query_engine_cls,
|
) as mock_query_engine_cls,
|
||||||
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
|
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
|
||||||
):
|
):
|
||||||
@ -125,8 +125,8 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
|
|||||||
|
|
||||||
def test_stream_chat_no_matching_nodes():
|
def test_stream_chat_no_matching_nodes():
|
||||||
with (
|
with (
|
||||||
patch("paperless.ai.chat.AIClient") as mock_client_cls,
|
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||||
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
|
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||||
):
|
):
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_client_cls.return_value = mock_client
|
mock_client_cls.return_value = mock_client
|
@ -4,12 +4,12 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
from llama_index.core.llms import ChatMessage
|
from llama_index.core.llms import ChatMessage
|
||||||
|
|
||||||
from paperless.ai.client import AIClient
|
from paperless_ai.client import AIClient
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_ai_config():
|
def mock_ai_config():
|
||||||
with patch("paperless.ai.client.AIConfig") as MockAIConfig:
|
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
MockAIConfig.return_value = mock_config
|
MockAIConfig.return_value = mock_config
|
||||||
yield mock_config
|
yield mock_config
|
||||||
@ -17,13 +17,13 @@ def mock_ai_config():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_ollama_llm():
|
def mock_ollama_llm():
|
||||||
with patch("paperless.ai.client.Ollama") as MockOllama:
|
with patch("paperless_ai.client.Ollama") as MockOllama:
|
||||||
yield MockOllama
|
yield MockOllama
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_openai_llm():
|
def mock_openai_llm():
|
||||||
with patch("paperless.ai.client.OpenAI") as MockOpenAI:
|
with patch("paperless_ai.client.OpenAI") as MockOpenAI:
|
||||||
yield MockOpenAI
|
yield MockOpenAI
|
||||||
|
|
||||||
|
|
@ -4,15 +4,15 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless.ai.embedding import build_llm_index_text
|
|
||||||
from paperless.ai.embedding import get_embedding_dim
|
|
||||||
from paperless.ai.embedding import get_embedding_model
|
|
||||||
from paperless.models import LLMEmbeddingBackend
|
from paperless.models import LLMEmbeddingBackend
|
||||||
|
from paperless_ai.embedding import build_llm_index_text
|
||||||
|
from paperless_ai.embedding import get_embedding_dim
|
||||||
|
from paperless_ai.embedding import get_embedding_model
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_ai_config():
|
def mock_ai_config():
|
||||||
with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
|
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
|
||||||
yield MockAIConfig
|
yield MockAIConfig
|
||||||
|
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ def test_get_embedding_model_openai(mock_ai_config):
|
|||||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
||||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||||
|
|
||||||
with patch("paperless.ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
|
with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
|
||||||
model = get_embedding_model()
|
model = get_embedding_model()
|
||||||
MockOpenAIEmbedding.assert_called_once_with(
|
MockOpenAIEmbedding.assert_called_once_with(
|
||||||
model="text-embedding-3-small",
|
model="text-embedding-3-small",
|
||||||
@ -72,7 +72,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"paperless.ai.embedding.HuggingFaceEmbedding",
|
"paperless_ai.embedding.HuggingFaceEmbedding",
|
||||||
) as MockHuggingFaceEmbedding:
|
) as MockHuggingFaceEmbedding:
|
||||||
model = get_embedding_model()
|
model = get_embedding_model()
|
||||||
MockHuggingFaceEmbedding.assert_called_once_with(
|
MockHuggingFaceEmbedding.assert_called_once_with(
|
@ -6,11 +6,11 @@ from documents.models import Correspondent
|
|||||||
from documents.models import DocumentType
|
from documents.models import DocumentType
|
||||||
from documents.models import StoragePath
|
from documents.models import StoragePath
|
||||||
from documents.models import Tag
|
from documents.models import Tag
|
||||||
from paperless.ai.matching import extract_unmatched_names
|
from paperless_ai.matching import extract_unmatched_names
|
||||||
from paperless.ai.matching import match_correspondents_by_name
|
from paperless_ai.matching import match_correspondents_by_name
|
||||||
from paperless.ai.matching import match_document_types_by_name
|
from paperless_ai.matching import match_document_types_by_name
|
||||||
from paperless.ai.matching import match_storage_paths_by_name
|
from paperless_ai.matching import match_storage_paths_by_name
|
||||||
from paperless.ai.matching import match_tags_by_name
|
from paperless_ai.matching import match_tags_by_name
|
||||||
|
|
||||||
|
|
||||||
class TestAIMatching(TestCase):
|
class TestAIMatching(TestCase):
|
||||||
@ -31,7 +31,7 @@ class TestAIMatching(TestCase):
|
|||||||
self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
|
self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
|
||||||
self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")
|
self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")
|
||||||
|
|
||||||
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
|
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||||
def test_match_tags_by_name(self, mock_get_objects):
|
def test_match_tags_by_name(self, mock_get_objects):
|
||||||
mock_get_objects.return_value = Tag.objects.all()
|
mock_get_objects.return_value = Tag.objects.all()
|
||||||
names = ["Test Tag 1", "Nonexistent Tag"]
|
names = ["Test Tag 1", "Nonexistent Tag"]
|
||||||
@ -39,7 +39,7 @@ class TestAIMatching(TestCase):
|
|||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
self.assertEqual(result[0].name, "Test Tag 1")
|
self.assertEqual(result[0].name, "Test Tag 1")
|
||||||
|
|
||||||
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
|
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||||
def test_match_correspondents_by_name(self, mock_get_objects):
|
def test_match_correspondents_by_name(self, mock_get_objects):
|
||||||
mock_get_objects.return_value = Correspondent.objects.all()
|
mock_get_objects.return_value = Correspondent.objects.all()
|
||||||
names = ["Test Correspondent 1", "Nonexistent Correspondent"]
|
names = ["Test Correspondent 1", "Nonexistent Correspondent"]
|
||||||
@ -47,7 +47,7 @@ class TestAIMatching(TestCase):
|
|||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
self.assertEqual(result[0].name, "Test Correspondent 1")
|
self.assertEqual(result[0].name, "Test Correspondent 1")
|
||||||
|
|
||||||
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
|
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||||
def test_match_document_types_by_name(self, mock_get_objects):
|
def test_match_document_types_by_name(self, mock_get_objects):
|
||||||
mock_get_objects.return_value = DocumentType.objects.all()
|
mock_get_objects.return_value = DocumentType.objects.all()
|
||||||
names = ["Test Document Type 1", "Nonexistent Document Type"]
|
names = ["Test Document Type 1", "Nonexistent Document Type"]
|
||||||
@ -55,7 +55,7 @@ class TestAIMatching(TestCase):
|
|||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
self.assertEqual(result[0].name, "Test Document Type 1")
|
self.assertEqual(result[0].name, "Test Document Type 1")
|
||||||
|
|
||||||
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
|
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||||
def test_match_storage_paths_by_name(self, mock_get_objects):
|
def test_match_storage_paths_by_name(self, mock_get_objects):
|
||||||
mock_get_objects.return_value = StoragePath.objects.all()
|
mock_get_objects.return_value = StoragePath.objects.all()
|
||||||
names = ["Test Storage Path 1", "Nonexistent Storage Path"]
|
names = ["Test Storage Path 1", "Nonexistent Storage Path"]
|
||||||
@ -69,14 +69,14 @@ class TestAIMatching(TestCase):
|
|||||||
unmatched_names = extract_unmatched_names(llm_names, matched_objects)
|
unmatched_names = extract_unmatched_names(llm_names, matched_objects)
|
||||||
self.assertEqual(unmatched_names, ["Nonexistent Tag"])
|
self.assertEqual(unmatched_names, ["Nonexistent Tag"])
|
||||||
|
|
||||||
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
|
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||||
def test_match_tags_by_name_with_empty_names(self, mock_get_objects):
|
def test_match_tags_by_name_with_empty_names(self, mock_get_objects):
|
||||||
mock_get_objects.return_value = Tag.objects.all()
|
mock_get_objects.return_value = Tag.objects.all()
|
||||||
names = [None, "", " "]
|
names = [None, "", " "]
|
||||||
result = match_tags_by_name(names, user=None)
|
result = match_tags_by_name(names, user=None)
|
||||||
self.assertEqual(result, [])
|
self.assertEqual(result, [])
|
||||||
|
|
||||||
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
|
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||||
def test_match_tags_with_fuzzy_matching(self, mock_get_objects):
|
def test_match_tags_with_fuzzy_matching(self, mock_get_objects):
|
||||||
mock_get_objects.return_value = Tag.objects.all()
|
mock_get_objects.return_value = Tag.objects.all()
|
||||||
names = ["Test Taag 1", "Teest Tag 2"]
|
names = ["Test Taag 1", "Teest Tag 2"]
|
Loading…
x
Reference in New Issue
Block a user