Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-11-04 03:27:12 -05:00)
Fixup some tests

This commit is contained in:
parent 9183bfc0a4
commit 4a28be233e
@@ -65,6 +65,8 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
                 "barcode_enable_tag": None,
                 "barcode_tag_mapping": None,
                 "ai_enabled": False,
+                "llm_embedding_backend": None,
+                "llm_embedding_model": None,
                 "llm_backend": None,
                 "llm_model": None,
                 "llm_api_key": None,
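Not part of the commit: the two new keys presumably correspond to fields added on paperless' application-configuration model. A minimal sketch under that assumption (the class name and field types are guesses; only the key names come from the test above):

from django.db import models


class ApplicationConfiguration(models.Model):  # class name assumed, not from this commit
    # ... existing fields ...
    llm_embedding_backend = models.CharField(max_length=32, null=True, blank=True)
    llm_embedding_model = models.CharField(max_length=128, null=True, blank=True)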
@@ -37,28 +37,65 @@ class OllamaLLM(LLM):
             data = response.json()
             return CompletionResponse(text=data["response"])
 
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        with httpx.Client(timeout=120.0) as client:
+            response = client.post(
+                f"{self.base_url}/api/generate",
+                json={
+                    "model": self.model,
+                    "messages": [
+                        {
+                            "role": message.role,
+                            "content": message.content,
+                        }
+                        for message in messages
+                    ],
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            return ChatResponse(text=data["response"])
+
     # -- Required stubs for ABC:
-    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+    def stream_complete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponseGen:  # pragma: no cover
         raise NotImplementedError("stream_complete not supported")
 
-    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
-        raise NotImplementedError("chat not supported")
-
-    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+    def stream_chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponseGen:  # pragma: no cover
         raise NotImplementedError("stream_chat not supported")
 
-    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+    async def achat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponse:  # pragma: no cover
         raise NotImplementedError("async chat not supported")
 
     async def astream_chat(
         self,
         messages: list[ChatMessage],
         **kwargs,
-    ) -> ChatResponseGen:
+    ) -> ChatResponseGen:  # pragma: no cover
         raise NotImplementedError("async stream_chat not supported")
 
-    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+    async def acomplete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponse:  # pragma: no cover
         raise NotImplementedError("async complete not supported")
 
-    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+    async def astream_complete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponseGen:  # pragma: no cover
         raise NotImplementedError("async stream_complete not supported")
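A brief usage sketch of the chat() method added above. The constructor arguments match how the client tests below instantiate OllamaLLM; the import path and model name are assumptions, not taken from this commit:

from llama_index.core.llms import ChatMessage

from paperless.ai.llms import OllamaLLM  # import path assumed

# Example values only; the call pattern follows the class as shown in the diff.
llm = OllamaLLM(model="llama3", base_url="http://localhost:11434")
response = llm.chat([ChatMessage(role="user", content="Summarize this document.")])
print(response.text)  # ChatResponse(text=...) as constructed in the diff above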
@@ -1419,10 +1419,9 @@ OUTLOOK_OAUTH_ENABLED = bool(
 AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
 LLM_EMBEDDING_BACKEND = os.getenv(
     "PAPERLESS_LLM_EMBEDDING_BACKEND",
-    "local",
-)  # or "openai"
+)  # "local" or "openai"
 LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
-LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "ollama")  # or "openai"
+LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND")  # "ollama" or "openai"
 LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
 LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
 LLM_URL = os.getenv("PAPERLESS_LLM_URL")
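With the defaults removed, both backends now have to be configured explicitly. An illustration of the environment variables these settings read (all values are examples; only the variable names come from the settings block above):

import os

os.environ["PAPERLESS_AI_ENABLED"] = "true"
os.environ["PAPERLESS_LLM_BACKEND"] = "ollama"              # or "openai"
os.environ["PAPERLESS_LLM_MODEL"] = "llama3"                # example model name
os.environ["PAPERLESS_LLM_URL"] = "http://localhost:11434"  # example URL
os.environ["PAPERLESS_LLM_EMBEDDING_BACKEND"] = "local"     # or "openai"
os.environ["PAPERLESS_LLM_EMBEDDING_MODEL"] = "example-embedding-model"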
@@ -1,11 +1,13 @@
 import json
+from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
+from django.test import override_settings
 
 from documents.models import Document
 from paperless.ai.ai_classifier import get_ai_document_classification
-from paperless.ai.ai_classifier import parse_ai_classification_response
+from paperless.ai.ai_classifier import parse_ai_response
 
 
 @pytest.fixture
@@ -15,8 +17,12 @@ def mock_document():
 
 @pytest.mark.django_db
 @patch("paperless.ai.client.AIClient.run_llm_query")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
 def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
-    mock_response = json.dumps(
+    mock_run_llm_query.return_value.text = json.dumps(
         {
             "title": "Test Title",
             "tags": ["test", "document"],
@@ -26,7 +32,6 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
             "dates": ["2023-01-01"],
         },
     )
-    mock_run_llm_query.return_value = mock_response
 
     result = get_ai_document_classification(mock_document)
 
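These two hunks follow from run_llm_query() now returning a response object that exposes .text (see the client tests below) rather than a raw string. A minimal sketch of the assumed call pattern in the classifier (the helper name is hypothetical; the .text access and empty-dict fallback mirror the tests):

import json


def classify_with(client, prompt):
    # Hypothetical helper, not code from this commit: run_llm_query() returns an
    # object with .text, which is parsed as JSON; invalid JSON yields {}.
    result = client.run_llm_query(prompt)
    try:
        return json.loads(result.text)
    except json.JSONDecodeError:
        return {}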
@@ -43,58 +48,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
 def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
     mock_run_llm_query.side_effect = Exception("LLM query failed")
 
-    result = get_ai_document_classification(mock_document)
-
-    assert result == {}
-
-
-def test_parse_llm_classification_response_valid():
-    mock_response = json.dumps(
-        {
-            "title": "Test Title",
-            "tags": ["test", "document"],
-            "correspondents": ["John Doe"],
-            "document_types": ["report"],
-            "storage_paths": ["Reports"],
-            "dates": ["2023-01-01"],
-        },
-    )
-
-    result = parse_ai_classification_response(mock_response)
-
-    assert result["title"] == "Test Title"
-    assert result["tags"] == ["test", "document"]
-    assert result["correspondents"] == ["John Doe"]
-    assert result["document_types"] == ["report"]
-    assert result["storage_paths"] == ["Reports"]
-    assert result["dates"] == ["2023-01-01"]
+    # assert raises an exception
+    with pytest.raises(Exception):
+        get_ai_document_classification(mock_document)
 
 
 def test_parse_llm_classification_response_invalid_json():
-    mock_response = "Invalid JSON"
+    mock_response = MagicMock()
+    mock_response.text = "Invalid JSON response"
 
-    result = parse_ai_classification_response(mock_response)
+    result = parse_ai_response(mock_response)
 
     assert result == {}
 
 
-def test_parse_llm_classification_response_partial_data():
-    mock_response = json.dumps(
-        {
-            "title": "Partial Data",
-            "tags": ["partial"],
-            "correspondents": "Jane Doe",
-            "document_types": "note",
-            "storage_paths": [],
-            "dates": [],
-        },
-    )
-
-    result = parse_ai_classification_response(mock_response)
-
-    assert result["title"] == "Partial Data"
-    assert result["tags"] == ["partial"]
-    assert result["correspondents"] == ["Jane Doe"]
-    assert result["document_types"] == ["note"]
-    assert result["storage_paths"] == []
-    assert result["dates"] == []
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
+@patch("paperless.ai.ai_classifier.build_prompt_with_rag")
+@override_settings(
+    LLM_EMBEDDING_BACKEND="local",
+    LLM_EMBEDDING_MODEL="some_model",
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_use_rag_if_configured(
+    mock_build_prompt_with_rag,
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_build_prompt_with_rag.return_value = "Prompt with RAG"
+    mock_run_llm_query.return_value.text = json.dumps({})
+    get_ai_document_classification(mock_document)
+    mock_build_prompt_with_rag.assert_called_once()
+
+
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
+@patch("paperless.ai.ai_classifier.build_prompt_without_rag")
+@patch("paperless.config.AIConfig")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_use_without_rag_if_not_configured(
+    mock_ai_config,
+    mock_build_prompt_without_rag,
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_ai_config.llm_embedding_backend = None
+    mock_build_prompt_without_rag.return_value = "Prompt without RAG"
+    mock_run_llm_query.return_value.text = json.dumps({})
+    get_ai_document_classification(mock_document)
+    mock_build_prompt_without_rag.assert_called_once()
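The new tests above pin down two behaviours: parse_ai_response() returns an empty dict on malformed JSON, and the classifier builds a RAG prompt only when an embedding backend is configured. A sketch of that selection logic inferred from the patch targets (the wiring is an assumption, not code from this commit):

def build_prompt_with_rag(document):  # stub standing in for the real helper
    return "Prompt with RAG"


def build_prompt_without_rag(document):  # stub standing in for the real helper
    return "Prompt without RAG"


def build_classification_prompt(document, ai_config):
    # RAG is only used when an embedding backend is configured, matching
    # test_use_rag_if_configured / test_use_without_rag_if_not_configured.
    if getattr(ai_config, "llm_embedding_backend", None):
        return build_prompt_with_rag(document)
    return build_prompt_without_rag(document)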
@@ -1,95 +1,93 @@
-import json
+from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
-from django.conf import settings
+from llama_index.core.llms import ChatMessage
 
 from paperless.ai.client import AIClient
 
 
 @pytest.fixture
-def mock_settings():
-    settings.LLM_BACKEND = "openai"
-    settings.LLM_MODEL = "gpt-3.5-turbo"
-    settings.LLM_API_KEY = "test-api-key"
-    yield settings
+def mock_ai_config():
+    with patch("paperless.ai.client.AIConfig") as MockAIConfig:
+        mock_config = MagicMock()
+        MockAIConfig.return_value = mock_config
+        yield mock_config
 
 
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
-    mock_settings.LLM_BACKEND = "openai"
-    mock_openai_query.return_value = "OpenAI response"
-    client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "OpenAI response"
-    mock_openai_query.assert_called_once_with("Test prompt")
-    mock_ollama_query.assert_not_called()
+@pytest.fixture
+def mock_ollama_llm():
+    with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
+        yield MockOllamaLLM
+
+
+@pytest.fixture
+def mock_openai_llm():
+    with patch("paperless.ai.client.OpenAI") as MockOpenAI:
+        yield MockOpenAI
+
+
+def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
+    client = AIClient()
+
+    mock_ollama_llm.assert_called_once_with(
+        model="test_model",
+        base_url="http://test-url",
+    )
+    assert client.llm == mock_ollama_llm.return_value
 
 
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
-    mock_settings.LLM_BACKEND = "ollama"
-    mock_ollama_query.return_value = "Ollama response"
-    client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "Ollama response"
-    mock_ollama_query.assert_called_once_with("Test prompt")
-    mock_openai_query.assert_not_called()
+def test_get_llm_openai(mock_ai_config, mock_openai_llm):
+    mock_ai_config.llm_backend = "openai"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.openai_api_key = "test_api_key"
+
+    client = AIClient()
+
+    mock_openai_llm.assert_called_once_with(
+        model="test_model",
+        api_key="test_api_key",
+    )
+    assert client.llm == mock_openai_llm.return_value
 
 
-@pytest.mark.django_db
-def test_run_llm_query_unsupported_backend(mock_settings):
-    mock_settings.LLM_BACKEND = "unsupported"
-    client = AIClient()
+def test_get_llm_unsupported_backend(mock_ai_config):
+    mock_ai_config.llm_backend = "unsupported"
+
     with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
-        client.run_llm_query("Test prompt")
+        AIClient()
 
 
-@pytest.mark.django_db
-def test_run_openai_query(httpx_mock, mock_settings):
-    mock_settings.LLM_BACKEND = "openai"
-    httpx_mock.add_response(
-        url="https://api.openai.com/v1/chat/completions",
-        json={
-            "choices": [{"message": {"content": "OpenAI response"}}],
-        },
-    )
+def test_run_llm_query(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
+    mock_llm_instance = mock_ollama_llm.return_value
+    mock_llm_instance.complete.return_value = "test_result"
 
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "OpenAI response"
+    result = client.run_llm_query("test_prompt")
 
-    request = httpx_mock.get_request()
-    assert request.method == "POST"
-    assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
-    assert request.headers["Content-Type"] == "application/json"
-    assert json.loads(request.content) == {
-        "model": mock_settings.LLM_MODEL,
-        "messages": [{"role": "user", "content": "Test prompt"}],
-        "temperature": 0.3,
-    }
+    mock_llm_instance.complete.assert_called_once_with("test_prompt")
+    assert result == "test_result"
 
 
-@pytest.mark.django_db
-def test_run_ollama_query(httpx_mock, mock_settings):
-    mock_settings.LLM_BACKEND = "ollama"
-    httpx_mock.add_response(
-        url="http://localhost:11434/api/chat",
-        json={"message": {"content": "Ollama response"}},
-    )
+def test_run_chat(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
+    mock_llm_instance = mock_ollama_llm.return_value
+    mock_llm_instance.chat.return_value = "test_chat_result"
 
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "Ollama response"
+    messages = [ChatMessage(role="user", content="Hello")]
+    result = client.run_chat(messages)
 
-    request = httpx_mock.get_request()
-    assert request.method == "POST"
-    assert json.loads(request.content) == {
-        "model": mock_settings.LLM_MODEL,
-        "messages": [{"role": "user", "content": "Test prompt"}],
-        "stream": False,
-    }
+    mock_llm_instance.chat.assert_called_once_with(messages)
+    assert result == "test_chat_result"
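Taken together, these tests describe the new AIClient contract: the backend is selected once at construction time from AIConfig, and queries delegate to the underlying llama-index LLM. A sketch of that contract inferred from the tests (names come from the patch targets; import paths and internal wiring are assumptions, not the commit's own implementation):

from llama_index.llms.openai import OpenAI   # import path assumed
from paperless.ai.llms import OllamaLLM      # import path assumed
from paperless.config import AIConfig        # import path assumed


class AIClient:
    def __init__(self):
        self.settings = AIConfig()  # attribute name assumed
        self.llm = self.get_llm()

    def get_llm(self):
        # Backend selection as exercised by test_get_llm_ollama / _openai /
        # _unsupported_backend.
        if self.settings.llm_backend == "ollama":
            return OllamaLLM(
                model=self.settings.llm_model,
                base_url=self.settings.llm_url,
            )
        if self.settings.llm_backend == "openai":
            return OpenAI(
                model=self.settings.llm_model,
                api_key=self.settings.openai_api_key,
            )
        raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")

    def run_llm_query(self, prompt: str):
        # Delegates to the LLM's complete(); see test_run_llm_query.
        return self.llm.complete(prompt)

    def run_chat(self, messages):
        # Delegates to the LLM's chat(); see test_run_chat.
        return self.llm.chat(messages)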