Mirror of https://github.com/paperless-ngx/paperless-ngx.git
Synced 2025-05-24 02:02:23 -04:00

Commit: fd8ffa62b0
Parent: 2cafe4a2c0

    Handle doc updates, refactor
@@ -11,6 +11,7 @@ class DocumentsConfig(AppConfig):
 from documents.signals import document_consumption_finished
 from documents.signals import document_updated
 from documents.signals.handlers import add_inbox_tags
+from documents.signals.handlers import add_or_update_document_in_llm_index
 from documents.signals.handlers import add_to_index
 from documents.signals.handlers import run_workflows_added
 from documents.signals.handlers import run_workflows_updated
@@ -26,6 +27,7 @@ class DocumentsConfig(AppConfig):
         document_consumption_finished.connect(set_storage_path)
         document_consumption_finished.connect(add_to_index)
         document_consumption_finished.connect(run_workflows_added)
+        document_consumption_finished.connect(add_or_update_document_in_llm_index)
         document_updated.connect(run_workflows_updated)

         import documents.schema  # noqa: F401
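
Note: the wiring above is plain Django signal dispatch. `document_consumption_finished` is a `Signal` that the consumer sends once a document has been stored, and every connected receiver is called with the document as a keyword argument. A minimal, self-contained sketch of those mechanics (the signal and receiver names mirror the diff; the send call and its payload are illustrative):

    # Runnable without a Django project: django.dispatch is standalone.
    from django.dispatch import Signal

    document_consumption_finished = Signal()

    def add_or_update_document_in_llm_index(sender, document, **kwargs):
        # In paperless-ngx this would enqueue an indexing task; here we just log.
        print(f"would index: {document}")

    document_consumption_finished.connect(add_or_update_document_in_llm_index)
    document_consumption_finished.send(sender=object, document="invoice.pdf")
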
@@ -46,6 +46,7 @@ from documents.models import WorkflowTrigger
 from documents.permissions import get_objects_for_user_owner_aware
 from documents.permissions import set_permissions_for_object
 from documents.templating.workflows import parse_w_workflow_placeholders
+from paperless.config import AIConfig

 if TYPE_CHECKING:
     from pathlib import Path
@@ -1403,3 +1404,26 @@ def task_failure_handler(
         task_instance.save()
     except Exception:  # pragma: no cover
         logger.exception("Updating PaperlessTask failed")
+
+
+def add_or_update_document_in_llm_index(sender, document, **kwargs):
+    """
+    Add or update a document in the LLM index when it is created or updated.
+    """
+    ai_config = AIConfig()
+    if ai_config.llm_index_enabled():
+        from documents.tasks import update_document_in_llm_index
+
+        update_document_in_llm_index.delay(document)
+
+
+@receiver(models.signals.post_delete, sender=Document)
+def delete_document_from_llm_index(sender, instance: Document, **kwargs):
+    """
+    Delete a document from the LLM index when it is deleted.
+    """
+    ai_config = AIConfig()
+    if ai_config.llm_index_enabled():
+        from documents.tasks import remove_document_from_llm_index
+
+        remove_document_from_llm_index.delay(instance)
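
Both handlers import their Celery task lazily, inside the function body. This is the usual way to break an import cycle: `documents.tasks` itself imports from the signal-handler side (visible in the hunks below), so a top-level import in either direction would be circular. A tiny runnable illustration of the deferred-import pattern (the stdlib `json` module stands in for the task module):

    def handler(sender, document, **kwargs):
        # Deferred import: resolved only when the handler fires, not at
        # module import time, so an import cycle never materializes.
        import json

        print(json.dumps({"event": "updated", "document": document}))

    handler(sender=None, document="invoice.pdf")
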
@@ -6,8 +6,6 @@ import uuid
 from pathlib import Path
 from tempfile import TemporaryDirectory

-import faiss
-import llama_index.core.settings as llama_settings
 import tqdm
 from celery import Task
 from celery import shared_task
@@ -19,13 +17,6 @@ from django.db import transaction
 from django.db.models.signals import post_save
 from django.utils import timezone
 from filelock import FileLock
-from llama_index.core import Document as LlamaDocument
-from llama_index.core import StorageContext
-from llama_index.core import VectorStoreIndex
-from llama_index.core.node_parser import SimpleNodeParser
-from llama_index.core.storage.docstore import SimpleDocumentStore
-from llama_index.core.storage.index_store import SimpleIndexStore
-from llama_index.vector_stores.faiss import FaissVectorStore
 from whoosh.writing import AsyncWriter

 from documents import index
@@ -61,9 +52,10 @@ from documents.sanity_checker import SanityCheckFailedException
 from documents.signals import document_updated
 from documents.signals.handlers import cleanup_document_deletion
 from documents.signals.handlers import run_workflows
-from paperless.ai.embedding import build_llm_index_text
-from paperless.ai.embedding import get_embedding_dim
-from paperless.ai.embedding import get_embedding_model
+from paperless.ai.indexing import llm_index_add_or_update_document
+from paperless.ai.indexing import llm_index_remove_document
+from paperless.ai.indexing import rebuild_llm_index
+from paperless.config import AIConfig

 if settings.AUDIT_LOG_ENABLED:
     from auditlog.models import LogEntry
@@ -251,6 +243,11 @@ def bulk_update_documents(document_ids):
         for doc in documents:
             index.update_document(writer, doc)

+    ai_config = AIConfig()
+    if ai_config.llm_index_enabled():
+        for doc in documents:
+            llm_index_add_or_update_document(doc)
+

 @shared_task
 def update_document_content_maybe_archive_file(document_id):
@@ -350,6 +347,10 @@ def update_document_content_maybe_archive_file(document_id):
             with index.open_index_writer() as writer:
                 index.update_document(writer, document)

+            ai_config = AIConfig()
+            if ai_config.llm_index_enabled():
+                llm_index_add_or_update_document(document)
+
             clear_document_caches(document.pk)

     except Exception:
@@ -511,60 +512,25 @@ def check_scheduled_workflows():


 def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
-    if rebuild:
-        shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
-        settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
-
-    embed_model = get_embedding_model()
-    llama_settings.Settings.embed_model = embed_model
-
-    if rebuild or not settings.LLM_INDEX_DIR.exists():
-        embedding_dim = get_embedding_dim()
-        faiss_index = faiss.IndexFlatL2(embedding_dim)
-        vector_store = FaissVectorStore(faiss_index=faiss_index)
-    else:
-        vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
-
-    docstore = SimpleDocumentStore()
-    index_store = SimpleIndexStore()
-
-    storage_context = StorageContext.from_defaults(
-        docstore=docstore,
-        index_store=index_store,
-        persist_dir=settings.LLM_INDEX_DIR,
-        vector_store=vector_store,
-    )
-
-    parser = SimpleNodeParser()
-    nodes = []
-
-    for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
-        if not document.content:
-            continue
-
-        text = build_llm_index_text(document)
-        metadata = {
-            "document_id": document.id,
-            "title": document.title,
-            "tags": [t.name for t in document.tags.all()],
-            "correspondent": document.correspondent.name
-            if document.correspondent
-            else None,
-            "document_type": document.document_type.name
-            if document.document_type
-            else None,
-            "created": document.created.isoformat() if document.created else None,
-            "added": document.added.isoformat() if document.added else None,
-        }
-
-        doc = LlamaDocument(text=text, metadata=metadata)
-        doc_nodes = parser.get_nodes_from_documents([doc])
-        nodes.extend(doc_nodes)
-
-    index = VectorStoreIndex(
-        nodes=nodes,
-        storage_context=storage_context,
-        embed_model=embed_model,
-    )
-
-    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    rebuild_llm_index(
+        progress_bar_disable=progress_bar_disable,
+        rebuild=rebuild,
+    )
+
+
+@shared_task
+def update_document_in_llm_index(document):
+    llm_index_add_or_update_document(document)
+
+
+@shared_task
+def remove_document_from_llm_index(document):
+    llm_index_remove_document(document)
+
+
+# TODO: schedule to run periodically
+@shared_task
+def rebuild_llm_index_task():
+    from paperless.ai.indexing import rebuild_llm_index
+
+    rebuild_llm_index(rebuild=True)
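
The refactor leaves tasks.py with thin Celery wrappers that forward straight into paperless.ai.indexing, so the indexing logic stays importable and unit-testable without a running worker. One caveat: `.delay(document)` hands the whole model instance to Celery, which must serialize its arguments; passing only the primary key and re-fetching inside the task is the more common pattern. A minimal sketch of that variant (the task and function names are hypothetical):

    from celery import shared_task


    def reindex_document(document_id: int) -> None:
        # Plain function: call it directly in tests, no broker needed.
        print(f"reindexing document {document_id}")


    @shared_task
    def reindex_document_task(document_id: int) -> None:
        # Only the integer id crosses the broker; the worker re-fetches state.
        reindex_document(document_id)


    # Calling the task object directly runs it synchronously, without a broker.
    reindex_document_task(42)
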
@@ -7,7 +7,7 @@ from llama_index.core.query_engine import RetrieverQueryEngine

 from documents.models import Document
 from paperless.ai.client import AIClient
-from paperless.ai.indexing import load_index
+from paperless.ai.indexing import load_or_build_index

 logger = logging.getLogger("paperless.ai.chat")

@@ -24,7 +24,7 @@ CHAT_PROMPT_TMPL = PromptTemplate(

 def stream_chat_with_documents(query_str: str, documents: list[Document]):
     client = AIClient()
-    index = load_index()
+    index = load_or_build_index()

     doc_ids = [doc.pk for doc in documents]

@@ -1,44 +1,209 @@
 import logging
+import shutil
+
+import faiss
 import llama_index.core.settings as llama_settings
+import tqdm
 from django.conf import settings
+from llama_index.core import Document as LlamaDocument
 from llama_index.core import StorageContext
 from llama_index.core import VectorStoreIndex
-from llama_index.core import load_index_from_storage
+from llama_index.core.node_parser import SimpleNodeParser
 from llama_index.core.retrievers import VectorIndexRetriever
+from llama_index.core.schema import BaseNode
+from llama_index.core.storage.docstore import SimpleDocumentStore
+from llama_index.core.storage.index_store import SimpleIndexStore
 from llama_index.vector_stores.faiss import FaissVectorStore

 from documents.models import Document
+from paperless.ai.embedding import build_llm_index_text
+from paperless.ai.embedding import get_embedding_dim
 from paperless.ai.embedding import get_embedding_model

 logger = logging.getLogger("paperless.ai.indexing")


-def load_index() -> VectorStoreIndex:
-    """Loads the persisted LlamaIndex from disk."""
-    vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
-    embed_model = get_embedding_model()
-
-    llama_settings.Settings.embed_model = embed_model
-    llama_settings.Settings.chunk_size = 512
-
-    storage_context = StorageContext.from_defaults(
-        vector_store=vector_store,
-        persist_dir=settings.LLM_INDEX_DIR,
-    )
-    return load_index_from_storage(storage_context)
+def get_or_create_storage_context(*, rebuild=False):
+    """
+    Loads or creates the StorageContext (vector store, docstore, index store).
+    If rebuild=True, deletes and recreates everything.
+    """
+    if rebuild:
+        shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
+        settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
+
+    if rebuild or not settings.LLM_INDEX_DIR.exists():
+        embedding_dim = get_embedding_dim()
+        faiss_index = faiss.IndexFlatL2(embedding_dim)
+        vector_store = FaissVectorStore(faiss_index=faiss_index)
+        docstore = SimpleDocumentStore()
+        index_store = SimpleIndexStore()
+    else:
+        vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
+        docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
+        index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
+
+    return StorageContext.from_defaults(
+        docstore=docstore,
+        index_store=index_store,
+        vector_store=vector_store,
+        persist_dir=settings.LLM_INDEX_DIR,
+    )
+
+
+def get_vector_store_index(storage_context, embed_model):
+    """
+    Returns a VectorStoreIndex given a storage context and embed model.
+    """
+    return VectorStoreIndex(
+        storage_context=storage_context,
+        embed_model=embed_model,
+    )
+
+
+def build_document_node(document) -> list[BaseNode]:
+    """
+    Given a Document, returns parsed Nodes ready for indexing.
+    """
+    if not document.content:
+        return []
+
+    text = build_llm_index_text(document)
+    metadata = {
+        "document_id": document.id,
+        "title": document.title,
+        "tags": [t.name for t in document.tags.all()],
+        "correspondent": document.correspondent.name
+        if document.correspondent
+        else None,
+        "document_type": document.document_type.name
+        if document.document_type
+        else None,
+        "created": document.created.isoformat() if document.created else None,
+        "added": document.added.isoformat() if document.added else None,
+    }
+    doc = LlamaDocument(text=text, metadata=metadata)
+    parser = SimpleNodeParser()
+    return parser.get_nodes_from_documents([doc])
+
+
+def load_or_build_index(storage_context, embed_model, nodes=None):
+    """
+    Load an existing VectorStoreIndex if present,
+    or build a new one using provided nodes if storage is empty.
+    """
+    try:
+        return VectorStoreIndex(
+            storage_context=storage_context,
+            embed_model=embed_model,
+        )
+    except ValueError as e:
+        if "One of nodes, objects, or index_struct must be provided" in str(e):
+            if not nodes:
+                return None
+            return VectorStoreIndex(
+                nodes=nodes,
+                storage_context=storage_context,
+                embed_model=embed_model,
+            )
+        raise
+
+
+def remove_existing_document_nodes(document, index):
+    """
+    Removes a given document's existing nodes from the docstore.
+    This is necessary because FAISS IndexFlatL2 is append-only.
+    """
+    all_node_ids = list(index.docstore.docs.keys())
+    existing_nodes = [
+        node.node_id
+        for node in index.docstore.get_nodes(all_node_ids)
+        if node.metadata.get("document_id") == document.id
+    ]
+    for node_id in existing_nodes:
+        # Delete from the docstore; FAISS IndexFlatL2 is append-only
+        index.docstore.delete_document(node_id)
+
+
+def rebuild_llm_index(*, progress_bar_disable=False, rebuild=False):
+    """
+    Rebuilds the LLM index from scratch.
+    """
+    embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
+
+    storage_context = get_or_create_storage_context(rebuild=rebuild)
+
+    nodes = []
+
+    for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
+        document_nodes = build_document_node(document)
+        nodes.extend(document_nodes)
+
+    if not nodes:
+        raise RuntimeError(
+            "No nodes to index — check that documents are available and have content.",
+        )
+
+    VectorStoreIndex(
+        nodes=nodes,
+        storage_context=storage_context,
+        embed_model=embed_model,
+    )
+    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+
+
+def llm_index_add_or_update_document(document):
+    """
+    Adds or updates a document in the LLM index.
+    If the document already exists, it will be replaced.
+    """
+    embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
+
+    storage_context = get_or_create_storage_context(rebuild=False)
+
+    new_nodes = build_document_node(document)
+
+    index = load_or_build_index(storage_context, embed_model, nodes=new_nodes)
+
+    if index is None:
+        # Nothing to index
+        return
+
+    # Remove old nodes
+    remove_existing_document_nodes(document, index)
+
+    index.insert_nodes(new_nodes)
+
+    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+
+
+def llm_index_remove_document(document):
+    embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
+
+    storage_context = get_or_create_storage_context(rebuild=False)
+
+    index = load_or_build_index(storage_context, embed_model)
+    if index is None:
+        return  # Nothing to remove
+
+    # Remove old nodes
+    remove_existing_document_nodes(document, index)
+
+    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)


 def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
-    """Runs a similarity query and returns top-k similar Document objects."""
-    # Load the index
-    index = load_index()
+    """
+    Runs a similarity query and returns top-k similar Document objects.
+    """
+    index = load_or_build_index()
     retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)

     # Build query from the document text
     query_text = (document.title or "") + "\n" + (document.content or "")

-    # Query
     results = retriever.retrieve(query_text)

     # Each result.node.metadata["document_id"] should match our stored doc
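
One subtlety in remove_existing_document_nodes: the FAISS IndexFlatL2 store is treated as append-only here, so stale vectors are never physically deleted; only the docstore entries carrying the matching document_id metadata are dropped. The selection logic is a plain metadata filter, sketched stand-alone below with dicts standing in for LlamaIndex nodes (all names illustrative):

    docstore = {
        "n1": {"document_id": 1, "text": "old chunk A"},
        "n2": {"document_id": 2, "text": "unrelated doc"},
        "n3": {"document_id": 1, "text": "old chunk B"},
    }

    def remove_nodes_for_document(docstore: dict, document_id: int) -> None:
        # Collect ids first, then delete, to avoid mutating while iterating.
        stale = [
            node_id
            for node_id, meta in docstore.items()
            if meta.get("document_id") == document_id
        ]
        for node_id in stale:
            del docstore[node_id]

    remove_nodes_for_document(docstore, 1)
    print(sorted(docstore))  # ['n2']
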
@@ -199,3 +199,9 @@ class AIConfig(BaseConfig):
         self.llm_model = app_config.llm_model or settings.LLM_MODEL
         self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
         self.llm_url = app_config.llm_url or settings.LLM_URL
+
+    def llm_index_enabled(self) -> bool:
+        return bool(
+            self.ai_enabled
+            and self.llm_embedding_backend
+        )
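
Since llm_index_enabled is a method rather than a property, call sites must invoke it with parentheses: a bare `if ai_config.llm_index_enabled:` always passes, because it tests the bound-method object itself rather than the flag. A short demonstration of the pitfall:

    class Config:
        def llm_index_enabled(self) -> bool:
            return False

    cfg = Config()
    print(bool(cfg.llm_index_enabled))  # True: a bound method is always truthy
    print(cfg.llm_index_enabled())      # False: the actual flag value
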
@@ -45,7 +45,7 @@ def mock_document():
 def test_stream_chat_with_one_document_full_content(mock_document):
     with (
         patch("paperless.ai.chat.AIClient") as mock_client_cls,
-        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
         patch(
             "paperless.ai.chat.RetrieverQueryEngine.from_args",
         ) as mock_query_engine_cls,
@@ -76,7 +76,7 @@ def test_stream_chat_with_one_document_full_content(mock_document):
 def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
     with (
         patch("paperless.ai.chat.AIClient") as mock_client_cls,
-        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
         patch(
             "paperless.ai.chat.RetrieverQueryEngine.from_args",
         ) as mock_query_engine_cls,
@@ -126,7 +126,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
 def test_stream_chat_no_matching_nodes():
     with (
         patch("paperless.ai.chat.AIClient") as mock_client_cls,
-        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
     ):
         mock_client = MagicMock()
         mock_client_cls.return_value = mock_client
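
These tests patch `paperless.ai.chat.load_or_build_index` rather than the name in paperless.ai.indexing: unittest.mock.patch must target the namespace where a name is looked up, and chat.py binds the function into its own module via a from-import. A self-contained illustration of that rule (the two modules below are synthesized on the fly and purely hypothetical):

    import sys
    import types
    from unittest.mock import patch

    # Fake "indexing" module holding the real function.
    indexing = types.ModuleType("indexing")
    indexing.load_or_build_index = lambda: "real index"
    sys.modules["indexing"] = indexing

    # Fake "chat" module that from-imports it, as paperless.ai.chat does.
    chat = types.ModuleType("chat")
    exec(
        "from indexing import load_or_build_index\n"
        "def ask():\n"
        "    return load_or_build_index()\n",
        chat.__dict__,
    )
    sys.modules["chat"] = chat

    with patch("indexing.load_or_build_index", return_value="mock"):
        print(chat.ask())  # "real index": definition site patched, use site untouched

    with patch("chat.load_or_build_index", return_value="mock"):
        print(chat.ask())  # "mock": use site patched
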
src/paperless/tests/test_ai_indexing.py (new file)
@@ -0,0 +1,144 @@
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+from django.utils import timezone
+from llama_index.core.base.embeddings.base import BaseEmbedding
+
+from documents.models import Document
+from paperless.ai import indexing
+
+
+@pytest.fixture
+def temp_llm_index_dir(tmp_path):
+    original_dir = indexing.settings.LLM_INDEX_DIR
+    indexing.settings.LLM_INDEX_DIR = tmp_path
+    yield tmp_path
+    indexing.settings.LLM_INDEX_DIR = original_dir
+
+
+@pytest.fixture
+def real_document(db):
+    return Document.objects.create(
+        title="Test Document",
+        content="This is some test content.",
+        added=timezone.now(),
+    )
+
+
+@pytest.fixture
+def mock_embed_model():
+    """Mocks the embedding model."""
+    with patch("paperless.ai.indexing.get_embedding_model") as mock:
+        mock.return_value = FakeEmbedding()
+        yield mock
+
+
+class FakeEmbedding(BaseEmbedding):
+    # TODO: maybe a better way to do this?
+    def _aget_query_embedding(self, query: str) -> list[float]:
+        return [0.1] * self.get_query_embedding_dim()
+
+    def _get_query_embedding(self, query: str) -> list[float]:
+        return [0.1] * self.get_query_embedding_dim()
+
+    def _get_text_embedding(self, text: str) -> list[float]:
+        return [0.1] * self.get_query_embedding_dim()
+
+    def get_query_embedding_dim(self) -> int:
+        return 384  # Match your real FAISS config
+
+
+@pytest.mark.django_db
+def test_build_document_node(real_document):
+    nodes = indexing.build_document_node(real_document)
+    assert len(nodes) > 0
+    assert nodes[0].metadata["document_id"] == real_document.id
+
+
+@pytest.mark.django_db
+def test_rebuild_llm_index(
+    temp_llm_index_dir,
+    real_document,
+    mock_embed_model,
+):
+    with patch("documents.models.Document.objects.all") as mock_all:
+        mock_all.return_value = [real_document]
+        indexing.rebuild_llm_index(rebuild=True)
+
+        assert any(temp_llm_index_dir.glob("*.json"))
+
+
+@pytest.mark.django_db
+def test_add_or_update_document_updates_existing_entry(
+    temp_llm_index_dir,
+    real_document,
+    mock_embed_model,
+):
+    indexing.rebuild_llm_index(rebuild=True)
+    indexing.llm_index_add_or_update_document(real_document)
+
+    assert any(temp_llm_index_dir.glob("*.json"))
+
+
+@pytest.mark.django_db
+def test_remove_document_deletes_node_from_docstore(
+    temp_llm_index_dir,
+    real_document,
+    mock_embed_model,
+):
+    indexing.rebuild_llm_index(rebuild=True)
+    indexing.llm_index_add_or_update_document(real_document)
+    indexing.llm_index_remove_document(real_document)
+
+    assert any(temp_llm_index_dir.glob("*.json"))
+
+
+@pytest.mark.django_db
+def test_rebuild_llm_index_no_documents(
+    temp_llm_index_dir,
+    mock_embed_model,
+):
+    with patch("documents.models.Document.objects.all") as mock_all:
+        mock_all.return_value = []
+
+        with pytest.raises(RuntimeError, match="No nodes to index"):
+            indexing.rebuild_llm_index(rebuild=True)
+
+
+def test_query_similar_documents(
+    temp_llm_index_dir,
+    real_document,
+):
+    with (
+        patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
+        patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
+        patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
+    ):
+        mock_index = MagicMock()
+        mock_load_or_build_index.return_value = mock_index
+
+        mock_retriever = MagicMock()
+        mock_retriever_cls.return_value = mock_retriever
+
+        mock_node1 = MagicMock()
+        mock_node1.metadata = {"document_id": 1}
+
+        mock_node2 = MagicMock()
+        mock_node2.metadata = {"document_id": 2}
+
+        mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
+
+        mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
+        mock_filter.return_value = mock_filtered_docs
+
+        result = indexing.query_similar_documents(real_document, top_k=3)
+
+        mock_load_or_build_index.assert_called_once()
+        mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
+        mock_retriever.retrieve.assert_called_once_with(
+            "Test Document\nThis is some test content.",
+        )
+        mock_filter.assert_called_once_with(pk__in=[1, 2])
+
+        assert result == mock_filtered_docs
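
The temp_llm_index_dir fixture swaps the module-level LLM_INDEX_DIR setting for tmp_path and restores it after the yield, so every test writes its index into an isolated directory. The same save/mutate/yield/restore shape works for any module attribute; a minimal pytest-only sketch (the config object is a hypothetical stand-in):

    import types

    import pytest

    config = types.SimpleNamespace(INDEX_DIR="/var/lib/index")


    @pytest.fixture
    def temp_index_dir(tmp_path):
        original = config.INDEX_DIR
        config.INDEX_DIR = tmp_path
        yield tmp_path
        # Teardown runs even if the test fails, restoring the shared state.
        config.INDEX_DIR = original


    def test_uses_temp_dir(temp_index_dir):
        assert config.INDEX_DIR == temp_index_dir

pytest's built-in monkeypatch.setattr achieves the same effect with automatic undo.
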
@@ -2,14 +2,11 @@ from unittest.mock import MagicMock
 from unittest.mock import patch

 import pytest
-from llama_index.core.base.embeddings.base import BaseEmbedding

 from documents.models import Document
 from paperless.ai.embedding import build_llm_index_text
 from paperless.ai.embedding import get_embedding_dim
 from paperless.ai.embedding import get_embedding_model
-from paperless.ai.indexing import load_index
-from paperless.ai.indexing import query_similar_documents
 from paperless.ai.rag import get_context_for_document
 from paperless.models import LLMEmbeddingBackend

@@ -182,93 +179,3 @@ def test_build_llm_index_text(mock_document):
     assert "Notes: Note1,Note2" in result
     assert "Content:\n\nThis is the document content." in result
     assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
-
-
-# Indexing
-
-
-@pytest.fixture
-def mock_settings(settings):
-    settings.LLM_INDEX_DIR = "/fake/path"
-    return settings
-
-
-class FakeEmbedding(BaseEmbedding):
-    # TODO: gotta be a better way to do this
-    def _aget_query_embedding(self, query: str) -> list[float]:
-        return [0.1, 0.2, 0.3]
-
-    def _get_query_embedding(self, query: str) -> list[float]:
-        return [0.1, 0.2, 0.3]
-
-    def _get_text_embedding(self, text: str) -> list[float]:
-        return [0.1, 0.2, 0.3]
-
-
-def test_load_index(mock_settings):
-    with (
-        patch("paperless.ai.indexing.FaissVectorStore.from_persist_dir") as mock_faiss,
-        patch("paperless.ai.indexing.get_embedding_model") as mock_get_embed_model,
-        patch(
-            "paperless.ai.indexing.StorageContext.from_defaults",
-        ) as mock_storage_context,
-        patch("paperless.ai.indexing.load_index_from_storage") as mock_load_index,
-    ):
-        # Setup mocks
-        mock_vector_store = MagicMock()
-        mock_storage = MagicMock()
-        mock_index = MagicMock()
-
-        mock_faiss.return_value = mock_vector_store
-        mock_storage_context.return_value = mock_storage
-        mock_load_index.return_value = mock_index
-        mock_get_embed_model.return_value = FakeEmbedding()
-
-        # Act
-        result = load_index()
-
-        # Assert
-        mock_faiss.assert_called_once_with("/fake/path")
-        mock_get_embed_model.assert_called_once()
-        mock_storage_context.assert_called_once_with(
-            vector_store=mock_vector_store,
-            persist_dir="/fake/path",
-        )
-        mock_load_index.assert_called_once_with(mock_storage)
-        assert result == mock_index
-
-
-def test_query_similar_documents(mock_document):
-    with (
-        patch("paperless.ai.indexing.load_index") as mock_load_index_func,
-        patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
-        patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
-    ):
-        # Setup mocks
-        mock_index = MagicMock()
-        mock_load_index_func.return_value = mock_index
-
-        mock_retriever = MagicMock()
-        mock_retriever_cls.return_value = mock_retriever
-
-        mock_node1 = MagicMock()
-        mock_node1.metadata = {"document_id": 1}
-
-        mock_node2 = MagicMock()
-        mock_node2.metadata = {"document_id": 2}
-
-        mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
-
-        mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
-        mock_filter.return_value = mock_filtered_docs
-
-        result = query_similar_documents(mock_document, top_k=3)
-
-        mock_load_index_func.assert_called_once()
-        mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
-        mock_retriever.retrieve.assert_called_once_with(
-            "Test Title\nThis is the document content.",
-        )
-        mock_filter.assert_called_once_with(pk__in=[1, 2])
-
-        assert result == mock_filtered_docs