mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-11-04 03:27:12 -05:00)

RAG into suggestions

This commit is contained in:
    parent 959ebdbb85
    commit aeceaf60a2
@@ -3,15 +3,13 @@ import logging
 
 from documents.models import Document
 from paperless.ai.client import AIClient
+from paperless.ai.rag import get_context_for_document
+from paperless.config import AIConfig
 
-logger = logging.getLogger("paperless.ai.ai_classifier")
+logger = logging.getLogger("paperless.ai.rag_classifier")
 
 
-def get_ai_document_classification(document: Document) -> dict:
-    """
-    Returns classification suggestions for a given document using an LLM.
-    Output schema matches the API's expected DocumentClassificationSuggestions format.
-    """
+def build_prompt_without_rag(document: Document) -> str:
     filename = document.filename or ""
     content = document.content or ""
 
@@ -41,6 +39,7 @@ def get_ai_document_classification(document: Document) -> dict:
     }}
     ---
 
+
     FILENAME:
     {filename}
 
@@ -48,39 +47,71 @@ def get_ai_document_classification(document: Document) -> dict:
     {content[:8000]}  # Trim to safe size
     """
 
-    try:
-        client = AIClient()
-        result = client.run_llm_query(prompt)
-        suggestions = parse_ai_classification_response(result)
-        return suggestions or {}
-    except Exception:
-        logger.exception("Error during LLM classification: %s", exc_info=True)
-        return {}
+    return prompt
 
 
-def parse_ai_classification_response(text: str) -> dict:
-    """
-    Parses LLM output and ensures it conforms to expected schema.
-    """
+def build_prompt_with_rag(document: Document) -> str:
+    context = get_context_for_document(document)
+    content = document.content or ""
+    filename = document.filename or ""
+
+    prompt = f"""
+    You are a helpful assistant that extracts structured information from documents.
+    You have access to similar documents as context to help improve suggestions.
+
+    Only output valid JSON in the format below. No additional explanations.
+
+    The JSON object must contain:
+    - title: A short, descriptive title
+    - tags: A list of relevant topics
+    - correspondents: People or organizations involved
+    - document_types: Type or category of the document
+    - storage_paths: Suggested folder paths
+    - dates: Up to 3 relevant dates in YYYY-MM-DD
+
+    Here is an example document:
+    FILENAME:
+    {filename}
+
+    CONTENT:
+    {content[:4000]}
+
+    CONTEXT FROM SIMILAR DOCUMENTS:
+    {context[:4000]}
+    """
+
+    return prompt
+
+
+def parse_ai_response(text: str) -> dict:
     try:
         raw = json.loads(text)
         return {
             "title": raw.get("title"),
             "tags": raw.get("tags", []),
-            "correspondents": [raw["correspondents"]]
-            if isinstance(raw.get("correspondents"), str)
-            else raw.get("correspondents", []),
-            "document_types": [raw["document_types"]]
-            if isinstance(raw.get("document_types"), str)
-            else raw.get("document_types", []),
+            "correspondents": raw.get("correspondents", []),
+            "document_types": raw.get("document_types", []),
             "storage_paths": raw.get("storage_paths", []),
-            "dates": [d for d in raw.get("dates", []) if d],
+            "dates": raw.get("dates", []),
         }
     except json.JSONDecodeError:
-        # fallback: try to extract JSON manually?
-        logger.exception(
-            "Failed to parse LLM classification response: %s",
-            text,
-            exc_info=True,
-        )
+        logger.exception("Invalid JSON in RAG response")
+        return {}
+
+
+def get_ai_document_classification(document: Document) -> dict:
+    ai_config = AIConfig()
+
+    prompt = (
+        build_prompt_with_rag(document)
+        if ai_config.llm_embedding_backend
+        else build_prompt_without_rag(document)
+    )
+
+    try:
+        client = AIClient()
+        result = client.run_llm_query(prompt)
+        return parse_ai_response(result)
+    except Exception:
+        logger.exception("Failed AI classification")
        return {}
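For context on how the refactored entry point is meant to be used: get_ai_document_classification() now picks a prompt builder (RAG-backed when an embedding backend is configured, plain otherwise), sends it through AIClient, and returns a plain dict parsed from the LLM's JSON. A minimal usage sketch, assuming the module lives at paperless.ai.ai_classifier (inferred from the old logger name) and an LLM backend is configured; the values in the trailing comment are hypothetical but follow the keys listed in the prompt:

    # Hypothetical usage sketch, not part of the commit.
    from documents.models import Document
    from paperless.ai.ai_classifier import get_ai_document_classification

    doc = Document.objects.first()  # any existing document
    suggestions = get_ai_document_classification(doc)
    # suggestions is e.g. (values depend entirely on the model's output):
    # {
    #     "title": "Electricity invoice March 2025",
    #     "tags": ["utilities", "invoice"],
    #     "correspondents": ["ACME Energy"],
    #     "document_types": ["invoice"],
    #     "storage_paths": ["Finance/Utilities"],
    #     "dates": ["2025-03-31"],
    # }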
12  src/paperless/ai/rag.py  Normal file
@@ -0,0 +1,12 @@
+from documents.models import Document
+from paperless.ai.indexing import query_similar_documents
+
+
+def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
+    similar_docs = query_similar_documents(doc)[:max_docs]
+    context_blocks = []
+    for similar in similar_docs:
+        text = similar.content or ""
+        title = similar.title or similar.filename or "Untitled"
+        context_blocks.append(f"TITLE: {title}\n{text}")
+    return "\n\n".join(context_blocks)
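The new rag.py helper builds the context string that build_prompt_with_rag() interpolates: one block per similar document, a TITLE: line followed by that document's content, with blank lines between blocks. A small sketch of calling it, assuming query_similar_documents returns the closest matches first; the primary key and titles below are hypothetical:

    # Hypothetical usage sketch, not part of the commit.
    from documents.models import Document
    from paperless.ai.rag import get_context_for_document

    doc = Document.objects.get(pk=42)  # hypothetical primary key
    context = get_context_for_document(doc, max_docs=3)
    # context is a single string of up to 3 blocks, e.g.:
    # TITLE: Invoice 2024-11
    # <full content of that document>
    #
    # TITLE: Invoice 2024-12
    # <full content of that document>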