Move AI code to its own module (paperless.ai -> paperless_ai)

This commit is contained in:
shamoon 2025-04-28 22:25:02 -07:00
parent 62588e9819
commit c5f618d822
No known key found for this signature in database
16 changed files with 80 additions and 80 deletions

View File

@ -52,10 +52,10 @@ from documents.sanity_checker import SanityCheckFailedException
from documents.signals import document_updated
from documents.signals.handlers import cleanup_document_deletion
from documents.signals.handlers import run_workflows
from paperless.ai.indexing import llm_index_add_or_update_document
from paperless.ai.indexing import llm_index_remove_document
from paperless.ai.indexing import update_llm_index
from paperless.config import AIConfig
from paperless_ai.indexing import llm_index_add_or_update_document
from paperless_ai.indexing import llm_index_remove_document
from paperless_ai.indexing import update_llm_index
if settings.AUDIT_LOG_ENABLED:
from auditlog.models import LogEntry
@ -536,6 +536,6 @@ def remove_document_from_llm_index(document):
# TODO: schedule to run periodically
@shared_task
def rebuild_llm_index_task():
from paperless.ai.indexing import update_llm_index
from paperless_ai.indexing import update_llm_index
update_llm_index(rebuild=True)

View File

@ -174,13 +174,6 @@ from documents.tasks import sanity_check
from documents.tasks import train_classifier
from documents.templating.filepath import validate_filepath_template_and_render
from paperless import version
from paperless.ai.ai_classifier import get_ai_document_classification
from paperless.ai.chat import stream_chat_with_documents
from paperless.ai.matching import extract_unmatched_names
from paperless.ai.matching import match_correspondents_by_name
from paperless.ai.matching import match_document_types_by_name
from paperless.ai.matching import match_storage_paths_by_name
from paperless.ai.matching import match_tags_by_name
from paperless.celery import app as celery_app
from paperless.config import AIConfig
from paperless.config import GeneralConfig
@ -188,6 +181,13 @@ from paperless.db import GnuPG
from paperless.serialisers import GroupSerializer
from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.chat import stream_chat_with_documents
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.oauth import PaperlessMailOAuth2Manager

View File

@ -6,11 +6,11 @@ from llama_index.core.base.llms.types import CompletionResponse
from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
from paperless.ai.client import AIClient
from paperless.ai.indexing import query_similar_documents
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.indexing import query_similar_documents
logger = logging.getLogger("paperless.ai.rag_classifier")
logger = logging.getLogger("paperless_ai.rag_classifier")
def build_prompt_without_rag(document: Document) -> str:

View File

@ -6,10 +6,10 @@ from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from documents.models import Document
from paperless.ai.client import AIClient
from paperless.ai.indexing import load_or_build_index
from paperless_ai.client import AIClient
from paperless_ai.indexing import load_or_build_index
logger = logging.getLogger("paperless.ai.chat")
logger = logging.getLogger("paperless_ai.chat")
CHAT_PROMPT_TMPL = PromptTemplate(
template="""Context information is below.

View File

@ -6,7 +6,7 @@ from llama_index.llms.openai import OpenAI
from paperless.config import AIConfig
logger = logging.getLogger("paperless.ai.client")
logger = logging.getLogger("paperless_ai.client")
class AIClient:

View File

@ -17,11 +17,11 @@ from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
from documents.models import Document
from paperless.ai.embedding import build_llm_index_text
from paperless.ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model
logger = logging.getLogger("paperless.ai.indexing")
logger = logging.getLogger("paperless_ai.indexing")
def get_or_create_storage_context(*, rebuild=False):

View File

@ -12,7 +12,7 @@ from documents.permissions import get_objects_for_user_owner_aware
MATCH_THRESHOLD = 0.8
logger = logging.getLogger("paperless.ai.matching")
logger = logging.getLogger("paperless_ai.matching")
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:

View File

View File

@ -6,11 +6,11 @@ import pytest
from django.test import override_settings
from documents.models import Document
from paperless.ai.ai_classifier import build_prompt_with_rag
from paperless.ai.ai_classifier import build_prompt_without_rag
from paperless.ai.ai_classifier import get_ai_document_classification
from paperless.ai.ai_classifier import get_context_for_document
from paperless.ai.ai_classifier import parse_ai_response
from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document
from paperless_ai.ai_classifier import parse_ai_response
@pytest.fixture
@ -49,7 +49,7 @@ def mock_document():
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
@ -77,7 +77,7 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
mock_run_llm_query.side_effect = Exception("LLM query failed")
@ -96,8 +96,8 @@ def test_parse_llm_classification_response_invalid_json():
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@patch("paperless.ai.ai_classifier.build_prompt_with_rag")
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_EMBEDDING_MODEL="some_model",
@ -116,8 +116,8 @@ def test_use_rag_if_configured(
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
@patch("paperless.ai.ai_classifier.build_prompt_without_rag")
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_without_rag")
@patch("paperless.config.AIConfig")
@override_settings(
LLM_BACKEND="ollama",
@ -144,7 +144,7 @@ def test_use_without_rag_if_not_configured(
)
def test_prompt_with_without_rag(mock_document):
with patch(
"paperless.ai.ai_classifier.get_context_for_document",
"paperless_ai.ai_classifier.get_context_for_document",
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
@ -174,7 +174,7 @@ def mock_similar_documents():
return [doc1, doc2, doc3]
@patch("paperless.ai.ai_classifier.query_similar_documents")
@patch("paperless_ai.ai_classifier.query_similar_documents")
def test_get_context_for_document(
mock_query_similar_documents,
mock_document,
@ -193,6 +193,6 @@ def test_get_context_for_document(
def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless.ai.ai_classifier.query_similar_documents", return_value=[]):
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document)
assert result == ""

View File

@ -7,7 +7,7 @@ from django.utils import timezone
from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
from paperless.ai import indexing
from paperless_ai import indexing
@pytest.fixture
@ -29,7 +29,7 @@ def real_document(db):
@pytest.fixture
def mock_embed_model():
with patch("paperless.ai.indexing.get_embedding_model") as mock:
with patch("paperless_ai.indexing.get_embedding_model") as mock:
mock.return_value = FakeEmbedding()
yield mock
@ -112,7 +112,7 @@ def test_update_llm_index_partial_update(
mock_all.return_value = mock_queryset
# assert logs "Updating %d nodes in LLM index."
with patch("paperless.ai.indexing.logger") as mock_logger:
with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=False)
mock_logger.info.assert_called_once_with(
"Updating %d nodes in LLM index.",
@ -139,15 +139,15 @@ def test_load_or_build_index_builds_when_nodes_given(
real_document,
):
with patch(
"paperless.ai.indexing.load_index_from_storage",
"paperless_ai.indexing.load_index_from_storage",
side_effect=ValueError("Index not found"),
):
with patch(
"paperless.ai.indexing.VectorStoreIndex",
"paperless_ai.indexing.VectorStoreIndex",
return_value=MagicMock(),
) as mock_index_cls:
with patch(
"paperless.ai.indexing.get_or_create_storage_context",
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
) as mock_storage:
mock_storage.return_value.persist_dir = temp_llm_index_dir
@ -161,7 +161,7 @@ def test_load_or_build_index_raises_exception_when_no_nodes(
temp_llm_index_dir,
):
with patch(
"paperless.ai.indexing.load_index_from_storage",
"paperless_ai.indexing.load_index_from_storage",
side_effect=ValueError("Index not found"),
):
with pytest.raises(Exception):
@ -207,7 +207,7 @@ def test_update_llm_index_no_documents(
mock_all.return_value = mock_queryset
# check log message
with patch("paperless.ai.indexing.logger") as mock_logger:
with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=True)
mock_logger.warning.assert_called_once_with(
"No documents found to index.",
@ -223,10 +223,10 @@ def test_query_similar_documents(
real_document,
):
with (
patch("paperless.ai.indexing.get_or_create_storage_context") as mock_storage,
patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
):
mock_storage.return_value = MagicMock()
mock_storage.return_value.persist_dir = temp_llm_index_dir
@ -251,7 +251,7 @@ def test_query_similar_documents(
result = indexing.query_similar_documents(real_document, top_k=3)
mock_load_or_build_index.assert_called_once()
mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
mock_retriever_cls.assert_called_once()
mock_retriever.retrieve.assert_called_once_with(
"Test Document\nThis is some test content.",
)

View File

@ -5,7 +5,7 @@ import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from paperless.ai.chat import stream_chat_with_documents
from paperless_ai.chat import stream_chat_with_documents
@pytest.fixture(autouse=True)
@ -44,10 +44,10 @@ def mock_document():
def test_stream_chat_with_one_document_full_content(mock_document):
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args",
"paperless_ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
):
mock_client = MagicMock()
@ -75,10 +75,10 @@ def test_stream_chat_with_one_document_full_content(mock_document):
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args",
"paperless_ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
):
@ -125,8 +125,8 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
def test_stream_chat_no_matching_nodes():
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client

View File

@ -4,12 +4,12 @@ from unittest.mock import patch
import pytest
from llama_index.core.llms import ChatMessage
from paperless.ai.client import AIClient
from paperless_ai.client import AIClient
@pytest.fixture
def mock_ai_config():
with patch("paperless.ai.client.AIConfig") as MockAIConfig:
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
mock_config = MagicMock()
MockAIConfig.return_value = mock_config
yield mock_config
@ -17,13 +17,13 @@ def mock_ai_config():
@pytest.fixture
def mock_ollama_llm():
with patch("paperless.ai.client.Ollama") as MockOllama:
with patch("paperless_ai.client.Ollama") as MockOllama:
yield MockOllama
@pytest.fixture
def mock_openai_llm():
with patch("paperless.ai.client.OpenAI") as MockOpenAI:
with patch("paperless_ai.client.OpenAI") as MockOpenAI:
yield MockOpenAI

View File

@ -4,15 +4,15 @@ from unittest.mock import patch
import pytest
from documents.models import Document
from paperless.ai.embedding import build_llm_index_text
from paperless.ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model
from paperless.models import LLMEmbeddingBackend
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model
@pytest.fixture
def mock_ai_config():
with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
yield MockAIConfig
@ -56,7 +56,7 @@ def test_get_embedding_model_openai(mock_ai_config):
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
mock_ai_config.return_value.llm_api_key = "test_api_key"
with patch("paperless.ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
model = get_embedding_model()
MockOpenAIEmbedding.assert_called_once_with(
model="text-embedding-3-small",
@ -72,7 +72,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
)
with patch(
"paperless.ai.embedding.HuggingFaceEmbedding",
"paperless_ai.embedding.HuggingFaceEmbedding",
) as MockHuggingFaceEmbedding:
model = get_embedding_model()
MockHuggingFaceEmbedding.assert_called_once_with(

View File

@ -6,11 +6,11 @@ from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from paperless.ai.matching import extract_unmatched_names
from paperless.ai.matching import match_correspondents_by_name
from paperless.ai.matching import match_document_types_by_name
from paperless.ai.matching import match_storage_paths_by_name
from paperless.ai.matching import match_tags_by_name
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name
class TestAIMatching(TestCase):
@ -31,7 +31,7 @@ class TestAIMatching(TestCase):
self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_tags_by_name(self, mock_get_objects):
mock_get_objects.return_value = Tag.objects.all()
names = ["Test Tag 1", "Nonexistent Tag"]
@ -39,7 +39,7 @@ class TestAIMatching(TestCase):
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Tag 1")
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_correspondents_by_name(self, mock_get_objects):
mock_get_objects.return_value = Correspondent.objects.all()
names = ["Test Correspondent 1", "Nonexistent Correspondent"]
@ -47,7 +47,7 @@ class TestAIMatching(TestCase):
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Correspondent 1")
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_document_types_by_name(self, mock_get_objects):
mock_get_objects.return_value = DocumentType.objects.all()
names = ["Test Document Type 1", "Nonexistent Document Type"]
@ -55,7 +55,7 @@ class TestAIMatching(TestCase):
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Document Type 1")
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_storage_paths_by_name(self, mock_get_objects):
mock_get_objects.return_value = StoragePath.objects.all()
names = ["Test Storage Path 1", "Nonexistent Storage Path"]
@ -69,14 +69,14 @@ class TestAIMatching(TestCase):
unmatched_names = extract_unmatched_names(llm_names, matched_objects)
self.assertEqual(unmatched_names, ["Nonexistent Tag"])
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_tags_by_name_with_empty_names(self, mock_get_objects):
mock_get_objects.return_value = Tag.objects.all()
names = [None, "", " "]
result = match_tags_by_name(names, user=None)
self.assertEqual(result, [])
@patch("paperless.ai.matching.get_objects_for_user_owner_aware")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_tags_with_fuzzy_matching(self, mock_get_objects):
mock_get_objects.return_value = Tag.objects.all()
names = ["Test Taag 1", "Teest Tag 2"]