diff --git a/src/documents/views.py b/src/documents/views.py
index 5eec8c43c..043565219 100644
--- a/src/documents/views.py
+++ b/src/documents/views.py
@@ -756,7 +756,7 @@ class DocumentViewSet(
                 refresh_suggestions_cache(doc.pk)
                 return Response(cached_llm_suggestions.suggestions)
 
-            llm_suggestions = get_ai_document_classification(doc)
+            llm_suggestions = get_ai_document_classification(doc, request.user)
 
             matched_tags = match_tags_by_name(
                 llm_suggestions.get("tags", []),
diff --git a/src/paperless/ai/ai_classifier.py b/src/paperless/ai/ai_classifier.py
index 6b34c0899..33101718d 100644
--- a/src/paperless/ai/ai_classifier.py
+++ b/src/paperless/ai/ai_classifier.py
@@ -1,9 +1,11 @@
 import json
 import logging
 
+from django.contrib.auth.models import User
 from llama_index.core.base.llms.types import CompletionResponse
 
 from documents.models import Document
+from documents.permissions import get_objects_for_user_owner_aware
 from paperless.ai.client import AIClient
 from paperless.ai.indexing import query_similar_documents
 from paperless.config import AIConfig
@@ -52,8 +54,8 @@ def build_prompt_without_rag(document: Document) -> str:
     return prompt
 
 
-def build_prompt_with_rag(document: Document) -> str:
-    context = get_context_for_document(document)
+def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
+    context = get_context_for_document(document, user)
     prompt = build_prompt_without_rag(document)
 
     prompt += f"""
@@ -65,8 +67,27 @@ def build_prompt_with_rag(document: Document) -> str:
     return prompt
 
 
-def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
-    similar_docs = query_similar_documents(doc)[:max_docs]
+def get_context_for_document(
+    doc: Document,
+    user: User | None = None,
+    max_docs: int = 5,
+) -> str:
+    visible_documents = (
+        get_objects_for_user_owner_aware(
+            user,
+            "view_document",
+            Document,
+        )
+        if user
+        else None
+    )
+    similar_docs = query_similar_documents(
+        document=doc,
+        # An empty permitted set must stay an empty filter, not "no filter".
+        document_ids=[document.pk for document in visible_documents]
+        if visible_documents is not None
+        else None,
+    )[:max_docs]
     context_blocks = []
     for similar in similar_docs:
         text = similar.content or ""
@@ -91,11 +112,14 @@ def parse_ai_response(response: CompletionResponse) -> dict:
         return {}
 
 
-def get_ai_document_classification(document: Document) -> dict:
+def get_ai_document_classification(
+    document: Document,
+    user: User | None = None,
+) -> dict:
     ai_config = AIConfig()
 
     prompt = (
-        build_prompt_with_rag(document)
+        build_prompt_with_rag(document, user)
         if ai_config.llm_embedding_backend
         else build_prompt_without_rag(document)
     )
diff --git a/src/paperless/ai/indexing.py b/src/paperless/ai/indexing.py
index 9a32409ca..3e354ba6d 100644
--- a/src/paperless/ai/indexing.py
+++ b/src/paperless/ai/indexing.py
@@ -206,12 +206,38 @@ def llm_index_remove_document(document: Document):
     index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
-def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
+def query_similar_documents(
+    document: Document,
+    top_k: int = 5,
+    document_ids: list[int] | None = None,
+) -> list[Document]:
     """
     Runs a similarity query and returns top-k similar Document objects.
     """
     index = load_or_build_index()
-    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
+
+    # Restrict retrieval to nodes of the permitted documents, if given.
+    # ``None`` means unrestricted; an empty list means nothing is permitted.
+    doc_node_ids = None
+    if document_ids is not None:
+        # Metadata may store the ID as int or str; compare as strings.
+        allowed_ids = {str(document_id) for document_id in document_ids}
+        doc_node_ids = [
+            node.node_id
+            for node in index.docstore.docs.values()
+            if str(node.metadata.get("document_id")) in allowed_ids
+        ]
+        if not doc_node_ids:
+            # Nothing permitted is indexed; never fall back to an open query.
+            return []
+
+    # NOTE(review): node IDs are passed as ``doc_ids``; confirm whether the
+    # retriever's ``node_ids`` argument is the intended constraint here.
+    retriever = VectorIndexRetriever(
+        index=index,
+        similarity_top_k=top_k,
+        doc_ids=doc_node_ids,
+    )
 
     query_text = (document.title or "") + "\n" + (document.content or "")
     results = retriever.retrieve(query_text)