Handle doc updates, refactor

shamoon 2025-04-27 01:24:00 -07:00
parent 2cafe4a2c0
commit fd8ffa62b0
9 changed files with 389 additions and 174 deletions

View File

@@ -11,6 +11,7 @@ class DocumentsConfig(AppConfig):
from documents.signals import document_consumption_finished
from documents.signals import document_updated
from documents.signals.handlers import add_inbox_tags
from documents.signals.handlers import add_or_update_document_in_llm_index
from documents.signals.handlers import add_to_index
from documents.signals.handlers import run_workflows_added
from documents.signals.handlers import run_workflows_updated
@@ -26,6 +27,7 @@ class DocumentsConfig(AppConfig):
document_consumption_finished.connect(set_storage_path)
document_consumption_finished.connect(add_to_index)
document_consumption_finished.connect(run_workflows_added)
document_consumption_finished.connect(add_or_update_document_in_llm_index)
document_updated.connect(run_workflows_updated)
import documents.schema # noqa: F401

View File

@@ -46,6 +46,7 @@ from documents.models import WorkflowTrigger
from documents.permissions import get_objects_for_user_owner_aware
from documents.permissions import set_permissions_for_object
from documents.templating.workflows import parse_w_workflow_placeholders
from paperless.config import AIConfig
if TYPE_CHECKING:
from pathlib import Path
@@ -1403,3 +1404,26 @@ def task_failure_handler(
task_instance.save()
except Exception: # pragma: no cover
logger.exception("Updating PaperlessTask failed")
def add_or_update_document_in_llm_index(sender, document, **kwargs):
"""
Add or update a document in the LLM index when it is created or updated.
"""
ai_config = AIConfig()
if ai_config.llm_index_enabled():
from documents.tasks import update_document_in_llm_index
update_document_in_llm_index.delay(document)
@receiver(models.signals.post_delete, sender=Document)
def delete_document_from_llm_index(sender, instance: Document, **kwargs):
"""
Delete a document from the LLM index when it is deleted.
"""
ai_config = AIConfig()
if ai_config.llm_index_enabled():
from documents.tasks import remove_document_from_llm_index
remove_document_from_llm_index.delay(instance)
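
For context, a minimal sketch (not part of the commit) of how this wiring behaves end to end; it assumes a test setting where Celery runs eagerly, so .delay() executes inline:

from documents.models import Document
from documents.signals import document_consumption_finished

doc = Document.objects.create(title="Invoice", content="...")
# Consumption fires every connected handler, including
# add_or_update_document_in_llm_index when the LLM index is enabled.
document_consumption_finished.send(sender=None, document=doc)
# Deleting the model fires the post_delete receiver, which enqueues
# remove_document_from_llm_index.
doc.delete()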

View File

@@ -6,8 +6,6 @@ import uuid
from pathlib import Path
from tempfile import TemporaryDirectory
import faiss
import llama_index.core.settings as llama_settings
import tqdm
from celery import Task
from celery import shared_task
@@ -19,13 +17,6 @@ from django.db import transaction
from django.db.models.signals import post_save
from django.utils import timezone
from filelock import FileLock
from llama_index.core import Document as LlamaDocument
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
from whoosh.writing import AsyncWriter
from documents import index
@@ -61,9 +52,10 @@ from documents.sanity_checker import SanityCheckFailedException
from documents.signals import document_updated
from documents.signals.handlers import cleanup_document_deletion
from documents.signals.handlers import run_workflows
from paperless.ai.embedding import build_llm_index_text
from paperless.ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model
from paperless.ai.indexing import llm_index_add_or_update_document
from paperless.ai.indexing import llm_index_remove_document
from paperless.ai.indexing import rebuild_llm_index
from paperless.config import AIConfig
if settings.AUDIT_LOG_ENABLED:
from auditlog.models import LogEntry
@@ -251,6 +243,11 @@ def bulk_update_documents(document_ids):
for doc in documents:
index.update_document(writer, doc)
ai_config = AIConfig()
if ai_config.llm_index_enabled():
for doc in documents:
llm_index_add_or_update_document(doc)
@shared_task
def update_document_content_maybe_archive_file(document_id):
@@ -350,6 +347,10 @@
with index.open_index_writer() as writer:
index.update_document(writer, document)
ai_config = AIConfig()
if ai_config.llm_index_enabled():
llm_index_add_or_update_document(document)
clear_document_caches(document.pk)
except Exception:
@@ -511,60 +512,25 @@ def check_scheduled_workflows():
def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
    rebuild_llm_index(
        progress_bar_disable=progress_bar_disable,
        rebuild=rebuild,
    )


@shared_task
def update_document_in_llm_index(document):
    llm_index_add_or_update_document(document)


@shared_task
def remove_document_from_llm_index(document):
    llm_index_remove_document(document)


# TODO: schedule to run periodically
@shared_task
def rebuild_llm_index_task():
    from paperless.ai.indexing import rebuild_llm_index

    rebuild_llm_index(rebuild=True)
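
A hedged usage sketch of the resulting task surface; it assumes a worker is consuming the default queue and that the broker serializer accepts a Document instance, as the signal handlers above already rely on:

from documents.tasks import rebuild_llm_index_task
from documents.tasks import update_document_in_llm_index

rebuild_llm_index_task.delay()                # full rebuild in the background
update_document_in_llm_index.delay(document)  # refresh one document's nodes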

View File

@@ -7,7 +7,7 @@ from llama_index.core.query_engine import RetrieverQueryEngine
from documents.models import Document
from paperless.ai.client import AIClient
from paperless.ai.indexing import load_index
from paperless.ai.indexing import load_or_build_index
logger = logging.getLogger("paperless.ai.chat")
@@ -24,7 +24,7 @@ CHAT_PROMPT_TMPL = PromptTemplate(
def stream_chat_with_documents(query_str: str, documents: list[Document]):
client = AIClient()
index = load_index()
index = load_or_build_index()
doc_ids = [doc.pk for doc in documents]
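
A brief, hedged example of the renamed entry point; it assumes an index already exists on disk and that stream_chat_with_documents yields response chunks:

from paperless.ai.chat import stream_chat_with_documents

for chunk in stream_chat_with_documents("Summarize these documents", documents=[doc]):
    print(chunk, end="")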

View File

@@ -1,44 +1,209 @@
import logging
import shutil
import faiss
import llama_index.core.settings as llama_settings
import tqdm
from django.conf import settings
from llama_index.core import Document as LlamaDocument
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.core import load_index_from_storage
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
from documents.models import Document
from paperless.ai.embedding import build_llm_index_text
from paperless.ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model
logger = logging.getLogger("paperless.ai.indexing")
def get_or_create_storage_context(*, rebuild=False):
    """
    Loads or creates the StorageContext (vector store, docstore, index store).
    If rebuild=True, deletes and recreates everything.
    """
    if rebuild:
        shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
    # Capture existence before mkdir, otherwise a freshly created directory
    # would be mistaken for a persisted index.
    index_exists = settings.LLM_INDEX_DIR.exists()
    settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)

    if rebuild or not index_exists:
        embedding_dim = get_embedding_dim()
        faiss_index = faiss.IndexFlatL2(embedding_dim)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        docstore = SimpleDocumentStore()
        index_store = SimpleIndexStore()
    else:
        vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
        docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
        index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)

    return StorageContext.from_defaults(
        docstore=docstore,
        index_store=index_store,
        vector_store=vector_store,
        persist_dir=settings.LLM_INDEX_DIR,
    )
def get_vector_store_index(storage_context, embed_model):
"""
Returns a VectorStoreIndex given a storage context and embed model.
"""
return VectorStoreIndex(
storage_context=storage_context,
embed_model=embed_model,
)
def build_document_node(document) -> list[BaseNode]:
"""
Given a Document, returns parsed Nodes ready for indexing.
"""
if not document.content:
return []
text = build_llm_index_text(document)
metadata = {
"document_id": document.id,
"title": document.title,
"tags": [t.name for t in document.tags.all()],
"correspondent": document.correspondent.name
if document.correspondent
else None,
"document_type": document.document_type.name
if document.document_type
else None,
"created": document.created.isoformat() if document.created else None,
"added": document.added.isoformat() if document.added else None,
}
doc = LlamaDocument(text=text, metadata=metadata)
parser = SimpleNodeParser()
return parser.get_nodes_from_documents([doc])
def load_or_build_index(storage_context=None, embed_model=None, nodes=None):
    """
    Load an existing VectorStoreIndex if present,
    or build a new one using the provided nodes if storage is empty.
    The storage context and embed model are created on demand, so callers
    such as chat and similarity search can invoke this with no arguments.
    """
    if embed_model is None:
        embed_model = get_embedding_model()
        llama_settings.Settings.embed_model = embed_model
    if storage_context is None:
        storage_context = get_or_create_storage_context()
    try:
        return VectorStoreIndex(
            storage_context=storage_context,
            embed_model=embed_model,
        )
    except ValueError as e:
        if "One of nodes, objects, or index_struct must be provided" in str(e):
            if not nodes:
                return None
            return VectorStoreIndex(
                nodes=nodes,
                storage_context=storage_context,
                embed_model=embed_model,
            )
        raise
def remove_existing_document_nodes(document, index):
    """
    Removes a document's existing nodes from the index's docstore.
    Only the docstore entries are removed: FAISS IndexFlatL2 is append-only,
    so stale vectors cannot be deleted from the vector store itself.
    """
    all_node_ids = list(index.docstore.docs.keys())
    existing_nodes = [
        node.node_id
        for node in index.docstore.get_nodes(all_node_ids)
        if node.metadata.get("document_id") == document.id
    ]
    for node_id in existing_nodes:
        index.docstore.delete_document(node_id)
def rebuild_llm_index(*, progress_bar_disable=False, rebuild=False):
"""
Rebuilds the LLM index from scratch.
"""
embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context(rebuild=rebuild)
nodes = []
for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
document_nodes = build_document_node(document)
nodes.extend(document_nodes)
if not nodes:
raise RuntimeError(
"No nodes to index — check that documents are available and have content.",
)
VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
)
storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
def llm_index_add_or_update_document(document):
"""
Adds or updates a document in the LLM index.
If the document already exists, it will be replaced.
"""
embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context(rebuild=False)
new_nodes = build_document_node(document)
index = load_or_build_index(storage_context, embed_model, nodes=new_nodes)
if index is None:
# Nothing to index
return
# Remove old nodes
remove_existing_document_nodes(document, index)
index.insert_nodes(new_nodes)
storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
def llm_index_remove_document(document):
embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context(rebuild=False)
index = load_or_build_index(storage_context, embed_model)
if index is None:
return # Nothing to remove
# Remove old nodes
remove_existing_document_nodes(document, index)
storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
    """
    Runs a similarity query and returns top-k similar Document objects.
    """
    index = load_or_build_index()
    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)

    # Build the query from the document's title and content
    query_text = (document.title or "") + "\n" + (document.content or "")
    results = retriever.retrieve(query_text)

    # Each result node's metadata["document_id"] matches a stored document
    document_ids = [
        node.metadata["document_id"]
        for node in results
        if "document_id" in node.metadata
    ]
    return Document.objects.filter(pk__in=document_ids)
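
To make the update cycle concrete, a hedged sketch of how the helpers above compose, assuming an embedding backend is configured; all names are the functions defined in this file:

from documents.models import Document
from paperless.ai import indexing

indexing.rebuild_llm_index(rebuild=True)        # fresh FAISS index and docstore
doc = Document.objects.first()
indexing.llm_index_add_or_update_document(doc)  # drop stale nodes, insert new ones
similar = indexing.query_similar_documents(doc, top_k=5)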

View File

@@ -199,3 +199,10 @@ class AIConfig(BaseConfig):
self.llm_model = app_config.llm_model or settings.LLM_MODEL
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
self.llm_url = app_config.llm_url or settings.LLM_URL
def llm_index_enabled(self) -> bool:
    return bool(self.ai_enabled and self.llm_embedding_backend)

View File

@@ -45,7 +45,7 @@ def mock_document():
def test_stream_chat_with_one_document_full_content(mock_document):
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_index") as mock_load_index,
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
@@ -76,7 +76,7 @@ def test_stream_chat_with_one_document_full_content(mock_document):
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_index") as mock_load_index,
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
@@ -126,7 +126,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
def test_stream_chat_no_matching_nodes():
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
patch("paperless.ai.chat.load_index") as mock_load_index,
patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client

View File

@@ -0,0 +1,144 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from django.utils import timezone
from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
from paperless.ai import indexing
@pytest.fixture
def temp_llm_index_dir(tmp_path):
original_dir = indexing.settings.LLM_INDEX_DIR
indexing.settings.LLM_INDEX_DIR = tmp_path
yield tmp_path
indexing.settings.LLM_INDEX_DIR = original_dir
@pytest.fixture
def real_document(db):
return Document.objects.create(
title="Test Document",
content="This is some test content.",
added=timezone.now(),
)
@pytest.fixture
def mock_embed_model():
"""Mocks the embedding model."""
with patch("paperless.ai.indexing.get_embedding_model") as mock:
mock.return_value = FakeEmbedding()
yield mock
class FakeEmbedding(BaseEmbedding):
# TODO: maybe a better way to do this?
def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def get_query_embedding_dim(self) -> int:
return 384  # matches the embedding dimension of the real FAISS config
@pytest.mark.django_db
def test_build_document_node(real_document):
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
assert nodes[0].metadata["document_id"] == real_document.id
@pytest.mark.django_db
def test_rebuild_llm_index(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
with patch("documents.models.Document.objects.all") as mock_all:
mock_all.return_value = [real_document]
indexing.rebuild_llm_index(rebuild=True)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_add_or_update_document_updates_existing_entry(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
indexing.rebuild_llm_index(rebuild=True)
indexing.llm_index_add_or_update_document(real_document)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_remove_document_deletes_node_from_docstore(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
indexing.rebuild_llm_index(rebuild=True)
indexing.llm_index_add_or_update_document(real_document)
indexing.llm_index_remove_document(real_document)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_rebuild_llm_index_no_documents(
temp_llm_index_dir,
mock_embed_model,
):
with patch("documents.models.Document.objects.all") as mock_all:
mock_all.return_value = []
with pytest.raises(RuntimeError, match="No nodes to index"):
indexing.rebuild_llm_index(rebuild=True)
def test_query_similar_documents(
temp_llm_index_dir,
real_document,
):
with (
patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
):
mock_index = MagicMock()
mock_load_or_build_index.return_value = mock_index
mock_retriever = MagicMock()
mock_retriever_cls.return_value = mock_retriever
mock_node1 = MagicMock()
mock_node1.metadata = {"document_id": 1}
mock_node2 = MagicMock()
mock_node2.metadata = {"document_id": 2}
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
mock_filter.return_value = mock_filtered_docs
result = indexing.query_similar_documents(real_document, top_k=3)
mock_load_or_build_index.assert_called_once()
mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
mock_retriever.retrieve.assert_called_once_with(
"Test Document\nThis is some test content.",
)
mock_filter.assert_called_once_with(pk__in=[1, 2])
assert result == mock_filtered_docs

View File

@@ -2,14 +2,11 @@ from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
from paperless.ai.embedding import build_llm_index_text
from paperless.ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model
from paperless.ai.indexing import load_index
from paperless.ai.indexing import query_similar_documents
from paperless.ai.rag import get_context_for_document
from paperless.models import LLMEmbeddingBackend
@@ -182,93 +179,3 @@ def test_build_llm_index_text(mock_document):
assert "Notes: Note1,Note2" in result
assert "Content:\n\nThis is the document content." in result
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
# Indexing
@pytest.fixture
def mock_settings(settings):
settings.LLM_INDEX_DIR = "/fake/path"
return settings
class FakeEmbedding(BaseEmbedding):
# TODO: gotta be a better way to do this
def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1, 0.2, 0.3]
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1, 0.2, 0.3]
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1, 0.2, 0.3]
def test_load_index(mock_settings):
with (
patch("paperless.ai.indexing.FaissVectorStore.from_persist_dir") as mock_faiss,
patch("paperless.ai.indexing.get_embedding_model") as mock_get_embed_model,
patch(
"paperless.ai.indexing.StorageContext.from_defaults",
) as mock_storage_context,
patch("paperless.ai.indexing.load_index_from_storage") as mock_load_index,
):
# Setup mocks
mock_vector_store = MagicMock()
mock_storage = MagicMock()
mock_index = MagicMock()
mock_faiss.return_value = mock_vector_store
mock_storage_context.return_value = mock_storage
mock_load_index.return_value = mock_index
mock_get_embed_model.return_value = FakeEmbedding()
# Act
result = load_index()
# Assert
mock_faiss.assert_called_once_with("/fake/path")
mock_get_embed_model.assert_called_once()
mock_storage_context.assert_called_once_with(
vector_store=mock_vector_store,
persist_dir="/fake/path",
)
mock_load_index.assert_called_once_with(mock_storage)
assert result == mock_index
def test_query_similar_documents(mock_document):
with (
patch("paperless.ai.indexing.load_index") as mock_load_index_func,
patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
):
# Setup mocks
mock_index = MagicMock()
mock_load_index_func.return_value = mock_index
mock_retriever = MagicMock()
mock_retriever_cls.return_value = mock_retriever
mock_node1 = MagicMock()
mock_node1.metadata = {"document_id": 1}
mock_node2 = MagicMock()
mock_node2.metadata = {"document_id": 2}
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
mock_filter.return_value = mock_filtered_docs
result = query_similar_documents(mock_document, top_k=3)
mock_load_index_func.assert_called_once()
mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
mock_retriever.retrieve.assert_called_once_with(
"Test Title\nThis is the document content.",
)
mock_filter.assert_called_once_with(pk__in=[1, 2])
assert result == mock_filtered_docs