mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-24 02:02:23 -04:00
Handle doc updates, refactor
This commit is contained in:
parent
2cafe4a2c0
commit
fd8ffa62b0
@ -11,6 +11,7 @@ class DocumentsConfig(AppConfig):
|
||||
from documents.signals import document_consumption_finished
|
||||
from documents.signals import document_updated
|
||||
from documents.signals.handlers import add_inbox_tags
|
||||
from documents.signals.handlers import add_or_update_document_in_llm_index
|
||||
from documents.signals.handlers import add_to_index
|
||||
from documents.signals.handlers import run_workflows_added
|
||||
from documents.signals.handlers import run_workflows_updated
|
||||
@ -26,6 +27,7 @@ class DocumentsConfig(AppConfig):
|
||||
document_consumption_finished.connect(set_storage_path)
|
||||
document_consumption_finished.connect(add_to_index)
|
||||
document_consumption_finished.connect(run_workflows_added)
|
||||
document_consumption_finished.connect(add_or_update_document_in_llm_index)
|
||||
document_updated.connect(run_workflows_updated)
|
||||
|
||||
import documents.schema # noqa: F401
|
||||
|
@ -46,6 +46,7 @@ from documents.models import WorkflowTrigger
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from documents.permissions import set_permissions_for_object
|
||||
from documents.templating.workflows import parse_w_workflow_placeholders
|
||||
from paperless.config import AIConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
@ -1403,3 +1404,26 @@ def task_failure_handler(
|
||||
task_instance.save()
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Updating PaperlessTask failed")
|
||||
|
||||
|
||||
def add_or_update_document_in_llm_index(sender, document, **kwargs):
    """
    Signal handler that queues an LLM index add/update for ``document``.

    Nothing is queued unless ``AIConfig().llm_index_enabled()`` is truthy.
    """
    if not AIConfig().llm_index_enabled():
        return

    # Kept as a function-level import, as in the original handler
    # (presumably to avoid a circular import with documents.tasks — confirm).
    from documents.tasks import update_document_in_llm_index

    update_document_in_llm_index.delay(document)
|
||||
|
||||
|
||||
@receiver(models.signals.post_delete, sender=Document)
def delete_document_from_llm_index(sender, instance: Document, **kwargs):
    """
    post_delete receiver that queues removal of ``instance`` from the LLM index.

    Nothing is queued unless ``AIConfig().llm_index_enabled()`` is truthy.
    """
    if not AIConfig().llm_index_enabled():
        return

    # Kept as a function-level import, as in the original receiver
    # (presumably to avoid a circular import with documents.tasks — confirm).
    from documents.tasks import remove_document_from_llm_index

    remove_document_from_llm_index.delay(instance)
|
||||
|
@ -6,8 +6,6 @@ import uuid
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import faiss
|
||||
import llama_index.core.settings as llama_settings
|
||||
import tqdm
|
||||
from celery import Task
|
||||
from celery import shared_task
|
||||
@ -19,13 +17,6 @@ from django.db import transaction
|
||||
from django.db.models.signals import post_save
|
||||
from django.utils import timezone
|
||||
from filelock import FileLock
|
||||
from llama_index.core import Document as LlamaDocument
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.node_parser import SimpleNodeParser
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
from whoosh.writing import AsyncWriter
|
||||
|
||||
from documents import index
|
||||
@ -61,9 +52,10 @@ from documents.sanity_checker import SanityCheckFailedException
|
||||
from documents.signals import document_updated
|
||||
from documents.signals.handlers import cleanup_document_deletion
|
||||
from documents.signals.handlers import run_workflows
|
||||
from paperless.ai.embedding import build_llm_index_text
|
||||
from paperless.ai.embedding import get_embedding_dim
|
||||
from paperless.ai.embedding import get_embedding_model
|
||||
from paperless.ai.indexing import llm_index_add_or_update_document
|
||||
from paperless.ai.indexing import llm_index_remove_document
|
||||
from paperless.ai.indexing import rebuild_llm_index
|
||||
from paperless.config import AIConfig
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
from auditlog.models import LogEntry
|
||||
@ -251,6 +243,11 @@ def bulk_update_documents(document_ids):
|
||||
for doc in documents:
|
||||
index.update_document(writer, doc)
|
||||
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled():
|
||||
for doc in documents:
|
||||
llm_index_add_or_update_document()
|
||||
|
||||
|
||||
@shared_task
|
||||
def update_document_content_maybe_archive_file(document_id):
|
||||
@ -350,6 +347,10 @@ def update_document_content_maybe_archive_file(document_id):
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, document)
|
||||
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled:
|
||||
llm_index_add_or_update_document(document)
|
||||
|
||||
clear_document_caches(document.pk)
|
||||
|
||||
except Exception:
|
||||
@ -511,60 +512,25 @@ def check_scheduled_workflows():
|
||||
|
||||
|
||||
def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
|
||||
if rebuild:
|
||||
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
|
||||
if rebuild or not settings.LLM_INDEX_DIR.exists():
|
||||
embedding_dim = get_embedding_dim()
|
||||
faiss_index = faiss.IndexFlatL2(embedding_dim)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
else:
|
||||
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
|
||||
docstore = SimpleDocumentStore()
|
||||
index_store = SimpleIndexStore()
|
||||
|
||||
storage_context = StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
index_store=index_store,
|
||||
persist_dir=settings.LLM_INDEX_DIR,
|
||||
vector_store=vector_store,
|
||||
rebuild_llm_index(
|
||||
progress_bar_disable=progress_bar_disable,
|
||||
rebuild=rebuild,
|
||||
)
|
||||
|
||||
parser = SimpleNodeParser()
|
||||
nodes = []
|
||||
|
||||
for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
|
||||
if not document.content:
|
||||
continue
|
||||
@shared_task
|
||||
def update_document_in_llm_index(document):
|
||||
llm_index_add_or_update_document(document)
|
||||
|
||||
text = build_llm_index_text(document)
|
||||
metadata = {
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"tags": [t.name for t in document.tags.all()],
|
||||
"correspondent": document.correspondent.name
|
||||
if document.correspondent
|
||||
else None,
|
||||
"document_type": document.document_type.name
|
||||
if document.document_type
|
||||
else None,
|
||||
"created": document.created.isoformat() if document.created else None,
|
||||
"added": document.added.isoformat() if document.added else None,
|
||||
}
|
||||
|
||||
doc = LlamaDocument(text=text, metadata=metadata)
|
||||
doc_nodes = parser.get_nodes_from_documents([doc])
|
||||
nodes.extend(doc_nodes)
|
||||
@shared_task
|
||||
def remove_document_from_llm_index(document):
|
||||
llm_index_remove_document(document)
|
||||
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
# TODO: schedule to run periodically
|
||||
@shared_task
|
||||
def rebuild_llm_index_task():
|
||||
from paperless.ai.indexing import rebuild_llm_index
|
||||
|
||||
rebuild_llm_index(rebuild=True)
|
||||
|
@ -7,7 +7,7 @@ from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.ai.client import AIClient
|
||||
from paperless.ai.indexing import load_index
|
||||
from paperless.ai.indexing import load_or_build_index
|
||||
|
||||
logger = logging.getLogger("paperless.ai.chat")
|
||||
|
||||
@ -24,7 +24,7 @@ CHAT_PROMPT_TMPL = PromptTemplate(
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
client = AIClient()
|
||||
index = load_index()
|
||||
index = load_or_build_index()
|
||||
|
||||
doc_ids = [doc.pk for doc in documents]
|
||||
|
||||
|
@ -1,44 +1,209 @@
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
import faiss
|
||||
import llama_index.core.settings as llama_settings
|
||||
import tqdm
|
||||
from django.conf import settings
|
||||
from llama_index.core import Document as LlamaDocument
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core import load_index_from_storage
|
||||
from llama_index.core.node_parser import SimpleNodeParser
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.ai.embedding import build_llm_index_text
|
||||
from paperless.ai.embedding import get_embedding_dim
|
||||
from paperless.ai.embedding import get_embedding_model
|
||||
|
||||
logger = logging.getLogger("paperless.ai.indexing")
|
||||
|
||||
|
||||
def load_index() -> VectorStoreIndex:
|
||||
"""Loads the persisted LlamaIndex from disk."""
|
||||
def get_or_create_storage_context(*, rebuild=False):
|
||||
"""
|
||||
Loads or creates the StorageContext (vector store, docstore, index store).
|
||||
If rebuild=True, deletes and recreates everything.
|
||||
"""
|
||||
if rebuild:
|
||||
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if rebuild or not settings.LLM_INDEX_DIR.exists():
|
||||
embedding_dim = get_embedding_dim()
|
||||
faiss_index = faiss.IndexFlatL2(embedding_dim)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
docstore = SimpleDocumentStore()
|
||||
index_store = SimpleIndexStore()
|
||||
else:
|
||||
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
embed_model = get_embedding_model()
|
||||
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
llama_settings.Settings.chunk_size = 512
|
||||
|
||||
storage_context = StorageContext.from_defaults(
|
||||
return StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
index_store=index_store,
|
||||
vector_store=vector_store,
|
||||
persist_dir=settings.LLM_INDEX_DIR,
|
||||
)
|
||||
return load_index_from_storage(storage_context)
|
||||
|
||||
|
||||
def get_vector_store_index(storage_context, embed_model):
    """
    Constructs a VectorStoreIndex from a storage context and an embed model.

    NOTE(review): no nodes or index_struct are passed to the constructor;
    confirm that VectorStoreIndex accepts being built this way.
    """
    index_kwargs = {
        "storage_context": storage_context,
        "embed_model": embed_model,
    }
    return VectorStoreIndex(**index_kwargs)
|
||||
|
||||
|
||||
def build_document_node(document) -> list[BaseNode]:
    """
    Turns a Document into the list of parsed nodes to store in the LLM index.

    A document without content yields an empty list. Otherwise the text from
    build_llm_index_text() is wrapped, together with identifying metadata,
    in a LlamaDocument and split by a SimpleNodeParser.
    """
    if not document.content:
        return []

    def name_or_none(related):
        # Optional relations contribute their name, or None when unset.
        return related.name if related else None

    def iso_or_none(moment):
        # Optional dates are stored as ISO strings, or None when unset.
        return moment.isoformat() if moment else None

    text = build_llm_index_text(document)
    metadata = {
        "document_id": document.id,
        "title": document.title,
        "tags": [tag.name for tag in document.tags.all()],
        "correspondent": name_or_none(document.correspondent),
        "document_type": name_or_none(document.document_type),
        "created": iso_or_none(document.created),
        "added": iso_or_none(document.added),
    }

    llama_doc = LlamaDocument(text=text, metadata=metadata)
    return SimpleNodeParser().get_nodes_from_documents([llama_doc])
|
||||
|
||||
|
||||
def load_or_build_index(storage_context=None, embed_model=None, nodes=None):
    """
    Load the persisted VectorStoreIndex if there is one, otherwise build a
    new index from ``nodes``.

    Args:
        storage_context: StorageContext to load from / build into. When
            omitted, get_or_create_storage_context(rebuild=False) is used.
        embed_model: Embedding model for the index. When omitted,
            get_embedding_model() is used.
        nodes: Nodes to build a new index from if none can be loaded.

    Returns:
        The loaded or newly built VectorStoreIndex, or None when no index
        could be loaded and no nodes were given.
    """
    # Both arguments are optional so that callers which pass nothing
    # (e.g. load_or_build_index()) get the default context and model.
    if embed_model is None:
        embed_model = get_embedding_model()
    if storage_context is None:
        storage_context = get_or_create_storage_context(rebuild=False)

    try:
        # Constructing VectorStoreIndex without nodes or an index_struct
        # raises, so an existing index has to be read back from the storage
        # context instead.
        return load_index_from_storage(storage_context, embed_model=embed_model)
    except ValueError as e:
        # Raised when the storage context holds no (single) loadable index.
        logger.info("Could not load LLM index from storage: %s", e)
        if not nodes:
            return None
        return VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            embed_model=embed_model,
        )
|
||||
|
||||
|
||||
def remove_existing_document_nodes(document, index):
    """
    Deletes every docstore node whose "document_id" metadata equals
    ``document.id``.

    Only the docstore is touched; the original comment notes that FAISS
    IndexFlatL2 is append-only.
    """
    docstore = index.docstore
    candidate_ids = list(docstore.docs.keys())

    # Collect all matching ids before deleting anything.
    stale_ids = []
    for node in docstore.get_nodes(candidate_ids):
        if node.metadata.get("document_id") == document.id:
            stale_ids.append(node.node_id)

    for stale_id in stale_ids:
        # Delete from docstore, FAISS IndexFlatL2 are append-only
        docstore.delete_document(stale_id)
|
||||
|
||||
|
||||
def rebuild_llm_index(*, progress_bar_disable=False, rebuild=False):
    """
    Indexes all documents into the LLM index and persists it.

    ``rebuild`` is forwarded to get_or_create_storage_context();
    ``progress_bar_disable`` turns off the tqdm progress bar.

    Raises RuntimeError when no document produces any node.
    """
    embed_model = get_embedding_model()
    llama_settings.Settings.embed_model = embed_model

    storage_context = get_or_create_storage_context(rebuild=rebuild)

    all_documents = tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable)
    nodes = [
        node
        for document in all_documents
        for node in build_document_node(document)
    ]

    if not nodes:
        raise RuntimeError(
            "No nodes to index — check that documents are available and have content.",
        )

    VectorStoreIndex(
        nodes=nodes,
        storage_context=storage_context,
        embed_model=embed_model,
    )
    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def llm_index_add_or_update_document(document):
    """
    Adds a document to the LLM index, replacing nodes already stored for it.

    Returns without persisting when load_or_build_index() yields no index.
    """
    model = get_embedding_model()
    llama_settings.Settings.embed_model = model

    context = get_or_create_storage_context(rebuild=False)
    fresh_nodes = build_document_node(document)

    index = load_or_build_index(context, model, nodes=fresh_nodes)
    if index is not None:
        # Drop the document's old docstore entries, then add the new nodes.
        remove_existing_document_nodes(document, index)
        index.insert_nodes(fresh_nodes)

        context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def llm_index_remove_document(document):
    """
    Removes a document's nodes from the LLM index and persists the result.

    Returns without persisting when load_or_build_index() yields no index.
    """
    embed_model = get_embedding_model()
    # Configure the model on the Settings object, as the other index functions
    # in this module do. Assigning to the llama_settings module itself only
    # creates an unrelated module attribute.
    llama_settings.Settings.embed_model = embed_model

    storage_context = get_or_create_storage_context(rebuild=False)

    index = load_or_build_index(storage_context, embed_model)
    if index is None:
        return  # Nothing to remove

    # Remove old nodes
    remove_existing_document_nodes(document, index)

    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
|
||||
"""Runs a similarity query and returns top-k similar Document objects."""
|
||||
# Load the index
|
||||
index = load_index()
|
||||
"""
|
||||
Runs a similarity query and returns top-k similar Document objects.
|
||||
"""
|
||||
index = load_or_build_index()
|
||||
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
|
||||
|
||||
# Build query from the document text
|
||||
query_text = (document.title or "") + "\n" + (document.content or "")
|
||||
|
||||
# Query
|
||||
results = retriever.retrieve(query_text)
|
||||
|
||||
# Each result.node.metadata["document_id"] should match our stored doc
|
||||
|
@ -199,3 +199,10 @@ class AIConfig(BaseConfig):
|
||||
self.llm_model = app_config.llm_model or settings.LLM_MODEL
|
||||
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
|
||||
self.llm_url = app_config.llm_url or settings.LLM_URL
|
||||
|
||||
def llm_index_enabled(self) -> bool:
    """
    Report whether the LLM index should be maintained.

    True only when AI is enabled and an embedding backend is configured.
    Always returns a real bool, as the annotation promises, rather than the
    backend value itself.
    """
    # NOTE(review): the original tested llm_embedding_backend twice; if a
    # second setting was intended, add it to this condition.
    return bool(self.ai_enabled and self.llm_embedding_backend)
|
||||
|
@ -45,7 +45,7 @@ def mock_document():
|
||||
def test_stream_chat_with_one_document_full_content(mock_document):
|
||||
with (
|
||||
patch("paperless.ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless.ai.chat.load_index") as mock_load_index,
|
||||
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless.ai.chat.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
@ -76,7 +76,7 @@ def test_stream_chat_with_one_document_full_content(mock_document):
|
||||
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
|
||||
with (
|
||||
patch("paperless.ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless.ai.chat.load_index") as mock_load_index,
|
||||
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless.ai.chat.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
@ -126,7 +126,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
|
||||
def test_stream_chat_no_matching_nodes():
|
||||
with (
|
||||
patch("paperless.ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless.ai.chat.load_index") as mock_load_index,
|
||||
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
144
src/paperless/tests/test_ai_indexing.py
Normal file
144
src/paperless/tests/test_ai_indexing.py
Normal file
@ -0,0 +1,144 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.utils import timezone
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.ai import indexing
|
||||
|
||||
|
||||
@pytest.fixture
def temp_llm_index_dir(tmp_path):
    """Points settings.LLM_INDEX_DIR at tmp_path for the test, then restores it."""
    previous_dir = indexing.settings.LLM_INDEX_DIR
    indexing.settings.LLM_INDEX_DIR = tmp_path
    yield tmp_path
    # Put the original directory back once the test is done.
    indexing.settings.LLM_INDEX_DIR = previous_dir
|
||||
|
||||
|
||||
@pytest.fixture
def real_document(db):
    """Creates a Document row with a title, content and an added timestamp."""
    fields = {
        "title": "Test Document",
        "content": "This is some test content.",
        "added": timezone.now(),
    }
    return Document.objects.create(**fields)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embed_model():
    """Patches indexing.get_embedding_model to return a FakeEmbedding."""
    with patch(
        "paperless.ai.indexing.get_embedding_model",
        return_value=FakeEmbedding(),
    ) as patched_getter:
        yield patched_getter
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
    """
    Deterministic embedding stub for tests: every input maps to the same
    constant vector of get_query_embedding_dim() values.
    """

    # TODO: maybe a better way to do this?
    async def _aget_query_embedding(self, query: str) -> list[float]:
        # Declared async so that awaiting this hook yields the vector;
        # a plain def would hand back a list where a coroutine is expected.
        return [0.1] * self.get_query_embedding_dim()

    def _get_query_embedding(self, query: str) -> list[float]:
        return [0.1] * self.get_query_embedding_dim()

    def _get_text_embedding(self, text: str) -> list[float]:
        return [0.1] * self.get_query_embedding_dim()

    def get_query_embedding_dim(self) -> int:
        return 384  # Match your real FAISS config
|
||||
|
||||
|
||||
@pytest.mark.django_db
def test_build_document_node(real_document):
    """build_document_node yields nodes tagged with the document's id."""
    built_nodes = indexing.build_document_node(real_document)

    assert len(built_nodes) > 0
    first_node = built_nodes[0]
    assert first_node.metadata["document_id"] == real_document.id
|
||||
|
||||
|
||||
@pytest.mark.django_db
def test_rebuild_llm_index(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    """A rebuild over one document leaves JSON files in the index directory."""
    with patch(
        "documents.models.Document.objects.all",
        return_value=[real_document],
    ):
        indexing.rebuild_llm_index(rebuild=True)

    persisted_files = temp_llm_index_dir.glob("*.json")
    assert any(persisted_files)
|
||||
|
||||
|
||||
@pytest.mark.django_db
def test_add_or_update_document_updates_existing_entry(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    """Rebuild, then add/update the same document; JSON files must exist."""
    indexing.rebuild_llm_index(rebuild=True)
    indexing.llm_index_add_or_update_document(real_document)

    persisted_files = temp_llm_index_dir.glob("*.json")
    assert any(persisted_files)
|
||||
|
||||
|
||||
@pytest.mark.django_db
def test_remove_document_deletes_node_from_docstore(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    """Rebuild, add/update, then remove the document; JSON files must exist."""
    # NOTE(review): only the presence of persisted files is asserted, not
    # that the document's nodes are actually gone from the docstore.
    indexing.rebuild_llm_index(rebuild=True)
    indexing.llm_index_add_or_update_document(real_document)
    indexing.llm_index_remove_document(real_document)

    persisted_files = temp_llm_index_dir.glob("*.json")
    assert any(persisted_files)
|
||||
|
||||
|
||||
@pytest.mark.django_db
def test_rebuild_llm_index_no_documents(
    temp_llm_index_dir,
    mock_embed_model,
):
    """With no documents at all, a rebuild raises RuntimeError."""
    with (
        patch("documents.models.Document.objects.all", return_value=[]),
        pytest.raises(RuntimeError, match="No nodes to index"),
    ):
        indexing.rebuild_llm_index(rebuild=True)
|
||||
|
||||
|
||||
def test_query_similar_documents(
    temp_llm_index_dir,
    real_document,
):
    """
    query_similar_documents builds a retriever on the loaded index, queries
    it with the document's title and content, and returns the Documents
    whose ids appear in the retrieved nodes' metadata.
    """
    fake_index = MagicMock()
    fake_retriever = MagicMock()
    fake_retriever.retrieve.return_value = [
        MagicMock(metadata={"document_id": 1}),
        MagicMock(metadata={"document_id": 2}),
    ]
    expected_docs = [MagicMock(pk=1), MagicMock(pk=2)]

    with (
        patch(
            "paperless.ai.indexing.load_or_build_index",
            return_value=fake_index,
        ) as mock_load_or_build_index,
        patch(
            "paperless.ai.indexing.VectorIndexRetriever",
            return_value=fake_retriever,
        ) as mock_retriever_cls,
        patch(
            "paperless.ai.indexing.Document.objects.filter",
            return_value=expected_docs,
        ) as mock_filter,
    ):
        result = indexing.query_similar_documents(real_document, top_k=3)

    mock_load_or_build_index.assert_called_once()
    mock_retriever_cls.assert_called_once_with(index=fake_index, similarity_top_k=3)
    fake_retriever.retrieve.assert_called_once_with(
        "Test Document\nThis is some test content.",
    )
    mock_filter.assert_called_once_with(pk__in=[1, 2])

    assert result == expected_docs
|
@ -2,14 +2,11 @@ from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.ai.embedding import build_llm_index_text
|
||||
from paperless.ai.embedding import get_embedding_dim
|
||||
from paperless.ai.embedding import get_embedding_model
|
||||
from paperless.ai.indexing import load_index
|
||||
from paperless.ai.indexing import query_similar_documents
|
||||
from paperless.ai.rag import get_context_for_document
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
|
||||
@ -182,93 +179,3 @@ def test_build_llm_index_text(mock_document):
|
||||
assert "Notes: Note1,Note2" in result
|
||||
assert "Content:\n\nThis is the document content." in result
|
||||
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
|
||||
|
||||
|
||||
# Indexing
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(settings):
|
||||
settings.LLM_INDEX_DIR = "/fake/path"
|
||||
return settings
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
# TODO: gotta be a better way to do this
|
||||
def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3]
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3]
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
def test_load_index(mock_settings):
|
||||
with (
|
||||
patch("paperless.ai.indexing.FaissVectorStore.from_persist_dir") as mock_faiss,
|
||||
patch("paperless.ai.indexing.get_embedding_model") as mock_get_embed_model,
|
||||
patch(
|
||||
"paperless.ai.indexing.StorageContext.from_defaults",
|
||||
) as mock_storage_context,
|
||||
patch("paperless.ai.indexing.load_index_from_storage") as mock_load_index,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_vector_store = MagicMock()
|
||||
mock_storage = MagicMock()
|
||||
mock_index = MagicMock()
|
||||
|
||||
mock_faiss.return_value = mock_vector_store
|
||||
mock_storage_context.return_value = mock_storage
|
||||
mock_load_index.return_value = mock_index
|
||||
mock_get_embed_model.return_value = FakeEmbedding()
|
||||
|
||||
# Act
|
||||
result = load_index()
|
||||
|
||||
# Assert
|
||||
mock_faiss.assert_called_once_with("/fake/path")
|
||||
mock_get_embed_model.assert_called_once()
|
||||
mock_storage_context.assert_called_once_with(
|
||||
vector_store=mock_vector_store,
|
||||
persist_dir="/fake/path",
|
||||
)
|
||||
mock_load_index.assert_called_once_with(mock_storage)
|
||||
assert result == mock_index
|
||||
|
||||
|
||||
def test_query_similar_documents(mock_document):
|
||||
with (
|
||||
patch("paperless.ai.indexing.load_index") as mock_load_index_func,
|
||||
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
|
||||
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_index = MagicMock()
|
||||
mock_load_index_func.return_value = mock_index
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
|
||||
mock_node1 = MagicMock()
|
||||
mock_node1.metadata = {"document_id": 1}
|
||||
|
||||
mock_node2 = MagicMock()
|
||||
mock_node2.metadata = {"document_id": 2}
|
||||
|
||||
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
|
||||
|
||||
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
|
||||
mock_filter.return_value = mock_filtered_docs
|
||||
|
||||
result = query_similar_documents(mock_document, top_k=3)
|
||||
|
||||
mock_load_index_func.assert_called_once()
|
||||
mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
|
||||
mock_retriever.retrieve.assert_called_once_with(
|
||||
"Test Title\nThis is the document content.",
|
||||
)
|
||||
mock_filter.assert_called_once_with(pk__in=[1, 2])
|
||||
|
||||
assert result == mock_filtered_docs
|
||||
|
Loading…
x
Reference in New Issue
Block a user