mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-24 23:39:05 -04:00 
			
		
		
		
	Use a frontend config
This commit is contained in:
		
							parent
							
								
									3cbea6cc51
								
							
						
					
					
						commit
						e14f508327
					
				| @ -49,6 +49,7 @@ export enum ConfigOptionType { | |||||||
| export const ConfigCategory = { | export const ConfigCategory = { | ||||||
|   General: $localize`General Settings`, |   General: $localize`General Settings`, | ||||||
|   OCR: $localize`OCR Settings`, |   OCR: $localize`OCR Settings`, | ||||||
|  |   AI: $localize`AI Settings`, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| export interface ConfigOption { | export interface ConfigOption { | ||||||
| @ -180,6 +181,41 @@ export const PaperlessConfigOptions: ConfigOption[] = [ | |||||||
|     config_key: 'PAPERLESS_APP_TITLE', |     config_key: 'PAPERLESS_APP_TITLE', | ||||||
|     category: ConfigCategory.General, |     category: ConfigCategory.General, | ||||||
|   }, |   }, | ||||||
|  |   { | ||||||
|  |     key: 'ai_enabled', | ||||||
|  |     title: $localize`AI Enabled`, | ||||||
|  |     type: ConfigOptionType.Boolean, | ||||||
|  |     config_key: 'PAPERLESS_AI_ENABLED', | ||||||
|  |     category: ConfigCategory.AI, | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |     key: 'llm_backend', | ||||||
|  |     title: $localize`LLM Backend`, | ||||||
|  |     type: ConfigOptionType.String, | ||||||
|  |     config_key: 'PAPERLESS_LLM_BACKEND', | ||||||
|  |     category: ConfigCategory.AI, | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |     key: 'llm_model', | ||||||
|  |     title: $localize`LLM Model`, | ||||||
|  |     type: ConfigOptionType.String, | ||||||
|  |     config_key: 'PAPERLESS_LLM_MODEL', | ||||||
|  |     category: ConfigCategory.AI, | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |     key: 'llm_api_key', | ||||||
|  |     title: $localize`LLM API Key`, | ||||||
|  |     type: ConfigOptionType.String, | ||||||
|  |     config_key: 'PAPERLESS_LLM_API_KEY', | ||||||
|  |     category: ConfigCategory.AI, | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |     key: 'llm_url', | ||||||
|  |     title: $localize`LLM URL`, | ||||||
|  |     type: ConfigOptionType.String, | ||||||
|  |     config_key: 'PAPERLESS_LLM_URL', | ||||||
|  |     category: ConfigCategory.AI, | ||||||
|  |   }, | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| export interface PaperlessConfig extends ObjectWithId { | export interface PaperlessConfig extends ObjectWithId { | ||||||
| @ -198,4 +234,9 @@ export interface PaperlessConfig extends ObjectWithId { | |||||||
|   user_args: object |   user_args: object | ||||||
|   app_logo: string |   app_logo: string | ||||||
|   app_title: string |   app_title: string | ||||||
|  |   ai_enabled: boolean | ||||||
|  |   llm_backend: string | ||||||
|  |   llm_model: string | ||||||
|  |   llm_api_key: string | ||||||
|  |   llm_url: string | ||||||
| } | } | ||||||
|  | |||||||
| @ -31,7 +31,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): | |||||||
|         response = self.client.get(self.ENDPOINT, format="json") |         response = self.client.get(self.ENDPOINT, format="json") | ||||||
| 
 | 
 | ||||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) |         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||||
| 
 |         self.maxDiff = None | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             json.dumps(response.data[0]), |             json.dumps(response.data[0]), | ||||||
|             json.dumps( |             json.dumps( | ||||||
| @ -52,6 +52,11 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): | |||||||
|                     "color_conversion_strategy": None, |                     "color_conversion_strategy": None, | ||||||
|                     "app_title": None, |                     "app_title": None, | ||||||
|                     "app_logo": None, |                     "app_logo": None, | ||||||
|  |                     "ai_enabled": False, | ||||||
|  |                     "llm_backend": None, | ||||||
|  |                     "llm_model": None, | ||||||
|  |                     "llm_api_key": None, | ||||||
|  |                     "llm_url": None, | ||||||
|                 }, |                 }, | ||||||
|             ), |             ), | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -177,6 +177,7 @@ from paperless.ai.matching import match_document_types_by_name | |||||||
| from paperless.ai.matching import match_storage_paths_by_name | from paperless.ai.matching import match_storage_paths_by_name | ||||||
| from paperless.ai.matching import match_tags_by_name | from paperless.ai.matching import match_tags_by_name | ||||||
| from paperless.celery import app as celery_app | from paperless.celery import app as celery_app | ||||||
|  | from paperless.config import AIConfig | ||||||
| from paperless.config import GeneralConfig | from paperless.config import GeneralConfig | ||||||
| from paperless.db import GnuPG | from paperless.db import GnuPG | ||||||
| from paperless.serialisers import GroupSerializer | from paperless.serialisers import GroupSerializer | ||||||
| @ -738,10 +739,12 @@ class DocumentViewSet( | |||||||
|         ): |         ): | ||||||
|             return HttpResponseForbidden("Insufficient permissions") |             return HttpResponseForbidden("Insufficient permissions") | ||||||
| 
 | 
 | ||||||
|         if settings.AI_ENABLED: |         ai_config = AIConfig() | ||||||
|  | 
 | ||||||
|  |         if ai_config.ai_enabled: | ||||||
|             cached_llm_suggestions = get_llm_suggestion_cache( |             cached_llm_suggestions = get_llm_suggestion_cache( | ||||||
|                 doc.pk, |                 doc.pk, | ||||||
|                 backend=settings.LLM_BACKEND, |                 backend=ai_config.llm_backend, | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             if cached_llm_suggestions: |             if cached_llm_suggestions: | ||||||
| @ -792,7 +795,7 @@ class DocumentViewSet( | |||||||
|                 "dates": llm_suggestions.get("dates", []), |                 "dates": llm_suggestions.get("dates", []), | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND) |             set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend) | ||||||
|         else: |         else: | ||||||
|             document_suggestions = get_suggestion_cache(doc.pk) |             document_suggestions = get_suggestion_cache(doc.pk) | ||||||
| 
 | 
 | ||||||
| @ -2220,7 +2223,10 @@ class UiSettingsView(GenericAPIView): | |||||||
|                 request.session["oauth_state"] = manager.state |                 request.session["oauth_state"] = manager.state | ||||||
| 
 | 
 | ||||||
|         ui_settings["email_enabled"] = settings.EMAIL_ENABLED |         ui_settings["email_enabled"] = settings.EMAIL_ENABLED | ||||||
|         ui_settings["ai_enabled"] = settings.AI_ENABLED | 
 | ||||||
|  |         ai_config = AIConfig() | ||||||
|  | 
 | ||||||
|  |         ui_settings["ai_enabled"] = ai_config.ai_enabled | ||||||
| 
 | 
 | ||||||
|         user_resp = { |         user_resp = { | ||||||
|             "id": user.id, |             "id": user.id, | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ import json | |||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| from documents.models import Document | from documents.models import Document | ||||||
| from paperless.ai.client import run_llm_query | from paperless.ai.client import AIClient | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger("paperless.ai.ai_classifier") | logger = logging.getLogger("paperless.ai.ai_classifier") | ||||||
| 
 | 
 | ||||||
| @ -49,7 +49,8 @@ def get_ai_document_classification(document: Document) -> dict: | |||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         result = run_llm_query(prompt) |         client = AIClient() | ||||||
|  |         result = client.run_llm_query(prompt) | ||||||
|         suggestions = parse_ai_classification_response(result) |         suggestions = parse_ai_classification_response(result) | ||||||
|         return suggestions or {} |         return suggestions or {} | ||||||
|     except Exception: |     except Exception: | ||||||
|  | |||||||
| @ -1,34 +1,45 @@ | |||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| import httpx | import httpx | ||||||
| from django.conf import settings | 
 | ||||||
|  | from paperless.config import AIConfig | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger("paperless.ai.client") | logger = logging.getLogger("paperless.ai.client") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def run_llm_query(prompt: str) -> str: | class AIClient: | ||||||
|  |     """ | ||||||
|  |     A client for interacting with an LLM backend. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self): | ||||||
|  |         self.settings = AIConfig() | ||||||
|  | 
 | ||||||
|  |     def run_llm_query(self, prompt: str) -> str: | ||||||
|         logger.debug( |         logger.debug( | ||||||
|             "Running LLM query against %s with model %s", |             "Running LLM query against %s with model %s", | ||||||
|         settings.LLM_BACKEND, |             self.settings.llm_backend, | ||||||
|         settings.LLM_MODEL, |             self.settings.llm_model, | ||||||
|         ) |         ) | ||||||
|     match settings.LLM_BACKEND: |         match self.settings.llm_backend: | ||||||
|             case "openai": |             case "openai": | ||||||
|             result = _run_openai_query(prompt) |                 result = self._run_openai_query(prompt) | ||||||
|             case "ollama": |             case "ollama": | ||||||
|             result = _run_ollama_query(prompt) |                 result = self._run_ollama_query(prompt) | ||||||
|             case _: |             case _: | ||||||
|             raise ValueError(f"Unsupported LLM backend: {settings.LLM_BACKEND}") |                 raise ValueError( | ||||||
|  |                     f"Unsupported LLM backend: {self.settings.llm_backend}", | ||||||
|  |                 ) | ||||||
|         logger.debug("LLM query result: %s", result) |         logger.debug("LLM query result: %s", result) | ||||||
|         return result |         return result | ||||||
| 
 | 
 | ||||||
| 
 |     def _run_ollama_query(self, prompt: str) -> str: | ||||||
| def _run_ollama_query(prompt: str) -> str: |         url = self.settings.llm_url or "http://localhost:11434" | ||||||
|         with httpx.Client(timeout=30.0) as client: |         with httpx.Client(timeout=30.0) as client: | ||||||
|             response = client.post( |             response = client.post( | ||||||
|             f"{settings.OLLAMA_URL}/api/chat", |                 f"{url}/api/chat", | ||||||
|                 json={ |                 json={ | ||||||
|                 "model": settings.LLM_MODEL, |                     "model": self.settings.llm_model, | ||||||
|                     "messages": [{"role": "user", "content": prompt}], |                     "messages": [{"role": "user", "content": prompt}], | ||||||
|                     "stream": False, |                     "stream": False, | ||||||
|                 }, |                 }, | ||||||
| @ -36,20 +47,21 @@ def _run_ollama_query(prompt: str) -> str: | |||||||
|             response.raise_for_status() |             response.raise_for_status() | ||||||
|             return response.json()["message"]["content"] |             return response.json()["message"]["content"] | ||||||
| 
 | 
 | ||||||
| 
 |     def _run_openai_query(self, prompt: str) -> str: | ||||||
| def _run_openai_query(prompt: str) -> str: |         if not self.settings.llm_api_key: | ||||||
|     if not settings.LLM_API_KEY: |  | ||||||
|             raise RuntimeError("PAPERLESS_LLM_API_KEY is not set") |             raise RuntimeError("PAPERLESS_LLM_API_KEY is not set") | ||||||
| 
 | 
 | ||||||
|  |         url = self.settings.llm_url or "https://api.openai.com" | ||||||
|  | 
 | ||||||
|         with httpx.Client(timeout=30.0) as client: |         with httpx.Client(timeout=30.0) as client: | ||||||
|             response = client.post( |             response = client.post( | ||||||
|             f"{settings.OPENAI_URL}/v1/chat/completions", |                 f"{url}/v1/chat/completions", | ||||||
|                 headers={ |                 headers={ | ||||||
|                 "Authorization": f"Bearer {settings.LLM_API_KEY}", |                     "Authorization": f"Bearer {self.settings.llm_api_key}", | ||||||
|                     "Content-Type": "application/json", |                     "Content-Type": "application/json", | ||||||
|                 }, |                 }, | ||||||
|                 json={ |                 json={ | ||||||
|                 "model": settings.LLM_MODEL, |                     "model": self.settings.llm_model, | ||||||
|                     "messages": [{"role": "user", "content": prompt}], |                     "messages": [{"role": "user", "content": prompt}], | ||||||
|                     "temperature": 0.3, |                     "temperature": 0.3, | ||||||
|                 }, |                 }, | ||||||
|  | |||||||
| @ -114,3 +114,25 @@ class GeneralConfig(BaseConfig): | |||||||
| 
 | 
 | ||||||
|         self.app_title = app_config.app_title or None |         self.app_title = app_config.app_title or None | ||||||
|         self.app_logo = app_config.app_logo.url if app_config.app_logo else None |         self.app_logo = app_config.app_logo.url if app_config.app_logo else None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @dataclasses.dataclass | ||||||
|  | class AIConfig(BaseConfig): | ||||||
|  |     """ | ||||||
|  |     AI related settings that require global scope | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     ai_enabled: bool = dataclasses.field(init=False) | ||||||
|  |     llm_backend: str = dataclasses.field(init=False) | ||||||
|  |     llm_model: str = dataclasses.field(init=False) | ||||||
|  |     llm_api_key: str = dataclasses.field(init=False) | ||||||
|  |     llm_url: str = dataclasses.field(init=False) | ||||||
|  | 
 | ||||||
|  |     def __post_init__(self) -> None: | ||||||
|  |         app_config = self._get_config_instance() | ||||||
|  | 
 | ||||||
|  |         self.ai_enabled = app_config.ai_enabled or settings.AI_ENABLED | ||||||
|  |         self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND | ||||||
|  |         self.llm_model = app_config.llm_model or settings.LLM_MODEL | ||||||
|  |         self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY | ||||||
|  |         self.llm_url = app_config.llm_url or settings.LLM_URL | ||||||
|  | |||||||
| @ -0,0 +1,63 @@ | |||||||
|  | # Generated by Django 5.1.7 on 2025-04-24 02:09 | ||||||
|  | 
 | ||||||
|  | from django.db import migrations | ||||||
|  | from django.db import models | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |     dependencies = [ | ||||||
|  |         ("paperless", "0003_alter_applicationconfiguration_max_image_pixels"), | ||||||
|  |     ] | ||||||
|  | 
 | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="applicationconfiguration", | ||||||
|  |             name="ai_enabled", | ||||||
|  |             field=models.BooleanField( | ||||||
|  |                 default=False, | ||||||
|  |                 null=True, | ||||||
|  |                 verbose_name="Enables AI features", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="applicationconfiguration", | ||||||
|  |             name="llm_api_key", | ||||||
|  |             field=models.CharField( | ||||||
|  |                 blank=True, | ||||||
|  |                 max_length=128, | ||||||
|  |                 null=True, | ||||||
|  |                 verbose_name="Sets the LLM API key", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="applicationconfiguration", | ||||||
|  |             name="llm_backend", | ||||||
|  |             field=models.CharField( | ||||||
|  |                 blank=True, | ||||||
|  |                 choices=[("openai", "OpenAI"), ("ollama", "Ollama")], | ||||||
|  |                 max_length=32, | ||||||
|  |                 null=True, | ||||||
|  |                 verbose_name="Sets the LLM backend", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="applicationconfiguration", | ||||||
|  |             name="llm_model", | ||||||
|  |             field=models.CharField( | ||||||
|  |                 blank=True, | ||||||
|  |                 max_length=32, | ||||||
|  |                 null=True, | ||||||
|  |                 verbose_name="Sets the LLM model", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="applicationconfiguration", | ||||||
|  |             name="llm_url", | ||||||
|  |             field=models.CharField( | ||||||
|  |                 blank=True, | ||||||
|  |                 max_length=128, | ||||||
|  |                 null=True, | ||||||
|  |                 verbose_name="Sets the LLM URL, optional", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -74,6 +74,15 @@ class ColorConvertChoices(models.TextChoices): | |||||||
|     CMYK = ("CMYK", _("CMYK")) |     CMYK = ("CMYK", _("CMYK")) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class LLMBackend(models.TextChoices): | ||||||
|  |     """ | ||||||
|  |     Matches to --llm-backend | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     OPENAI = ("openai", _("OpenAI")) | ||||||
|  |     OLLAMA = ("ollama", _("Ollama")) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class ApplicationConfiguration(AbstractSingletonModel): | class ApplicationConfiguration(AbstractSingletonModel): | ||||||
|     """ |     """ | ||||||
|     Settings which are common across more than 1 parser |     Settings which are common across more than 1 parser | ||||||
| @ -184,6 +193,45 @@ class ApplicationConfiguration(AbstractSingletonModel): | |||||||
|         upload_to="logo/", |         upload_to="logo/", | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|  |     """ | ||||||
|  |     AI related settings | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     ai_enabled = models.BooleanField( | ||||||
|  |         verbose_name=_("Enables AI features"), | ||||||
|  |         null=True, | ||||||
|  |         default=False, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     llm_backend = models.CharField( | ||||||
|  |         verbose_name=_("Sets the LLM backend"), | ||||||
|  |         null=True, | ||||||
|  |         blank=True, | ||||||
|  |         max_length=32, | ||||||
|  |         choices=LLMBackend.choices, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     llm_model = models.CharField( | ||||||
|  |         verbose_name=_("Sets the LLM model"), | ||||||
|  |         null=True, | ||||||
|  |         blank=True, | ||||||
|  |         max_length=32, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     llm_api_key = models.CharField( | ||||||
|  |         verbose_name=_("Sets the LLM API key"), | ||||||
|  |         null=True, | ||||||
|  |         blank=True, | ||||||
|  |         max_length=128, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     llm_url = models.CharField( | ||||||
|  |         verbose_name=_("Sets the LLM URL, optional"), | ||||||
|  |         null=True, | ||||||
|  |         blank=True, | ||||||
|  |         max_length=128, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("paperless application settings") |         verbose_name = _("paperless application settings") | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1275,5 +1275,4 @@ AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO") | |||||||
| LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "openai")  # or "ollama" | LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "openai")  # or "ollama" | ||||||
| LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL") | LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL") | ||||||
| LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY") | LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY") | ||||||
| OPENAI_URL = os.getenv("PAPERLESS_OPENAI_URL", "https://api.openai.com") | LLM_URL = os.getenv("PAPERLESS_LLM_URL") | ||||||
| OLLAMA_URL = os.getenv("PAPERLESS_OLLAMA_URL", "http://localhost:11434") |  | ||||||
|  | |||||||
| @ -13,7 +13,8 @@ def mock_document(): | |||||||
|     return Document(filename="test.pdf", content="This is a test document content.") |     return Document(filename="test.pdf", content="This is a test document content.") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @patch("paperless.ai.ai_classifier.run_llm_query") | @pytest.mark.django_db | ||||||
|  | @patch("paperless.ai.client.AIClient.run_llm_query") | ||||||
| def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): | def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): | ||||||
|     mock_response = json.dumps( |     mock_response = json.dumps( | ||||||
|         { |         { | ||||||
| @ -37,7 +38,8 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen | |||||||
|     assert result["dates"] == ["2023-01-01"] |     assert result["dates"] == ["2023-01-01"] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @patch("paperless.ai.ai_classifier.run_llm_query") | @pytest.mark.django_db | ||||||
|  | @patch("paperless.ai.client.AIClient.run_llm_query") | ||||||
| def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): | def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): | ||||||
|     mock_run_llm_query.side_effect = Exception("LLM query failed") |     mock_run_llm_query.side_effect = Exception("LLM query failed") | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -4,9 +4,7 @@ from unittest.mock import patch | |||||||
| import pytest | import pytest | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| 
 | 
 | ||||||
| from paperless.ai.client import _run_ollama_query | from paperless.ai.client import AIClient | ||||||
| from paperless.ai.client import _run_openai_query |  | ||||||
| from paperless.ai.client import run_llm_query |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| @ -14,52 +12,59 @@ def mock_settings(): | |||||||
|     settings.LLM_BACKEND = "openai" |     settings.LLM_BACKEND = "openai" | ||||||
|     settings.LLM_MODEL = "gpt-3.5-turbo" |     settings.LLM_MODEL = "gpt-3.5-turbo" | ||||||
|     settings.LLM_API_KEY = "test-api-key" |     settings.LLM_API_KEY = "test-api-key" | ||||||
|     settings.OPENAI_URL = "https://api.openai.com" |  | ||||||
|     settings.OLLAMA_URL = "https://ollama.example.com" |  | ||||||
|     yield settings |     yield settings | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @patch("paperless.ai.client._run_openai_query") | @pytest.mark.django_db | ||||||
| @patch("paperless.ai.client._run_ollama_query") | @patch("paperless.ai.client.AIClient._run_openai_query") | ||||||
|  | @patch("paperless.ai.client.AIClient._run_ollama_query") | ||||||
| def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings): | def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings): | ||||||
|  |     mock_settings.LLM_BACKEND = "openai" | ||||||
|     mock_openai_query.return_value = "OpenAI response" |     mock_openai_query.return_value = "OpenAI response" | ||||||
|     result = run_llm_query("Test prompt") |     client = AIClient() | ||||||
|  |     result = client.run_llm_query("Test prompt") | ||||||
|     assert result == "OpenAI response" |     assert result == "OpenAI response" | ||||||
|     mock_openai_query.assert_called_once_with("Test prompt") |     mock_openai_query.assert_called_once_with("Test prompt") | ||||||
|     mock_ollama_query.assert_not_called() |     mock_ollama_query.assert_not_called() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @patch("paperless.ai.client._run_openai_query") | @pytest.mark.django_db | ||||||
| @patch("paperless.ai.client._run_ollama_query") | @patch("paperless.ai.client.AIClient._run_openai_query") | ||||||
|  | @patch("paperless.ai.client.AIClient._run_ollama_query") | ||||||
| def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings): | def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings): | ||||||
|     mock_settings.LLM_BACKEND = "ollama" |     mock_settings.LLM_BACKEND = "ollama" | ||||||
|     mock_ollama_query.return_value = "Ollama response" |     mock_ollama_query.return_value = "Ollama response" | ||||||
|     result = run_llm_query("Test prompt") |     client = AIClient() | ||||||
|  |     result = client.run_llm_query("Test prompt") | ||||||
|     assert result == "Ollama response" |     assert result == "Ollama response" | ||||||
|     mock_ollama_query.assert_called_once_with("Test prompt") |     mock_ollama_query.assert_called_once_with("Test prompt") | ||||||
|     mock_openai_query.assert_not_called() |     mock_openai_query.assert_not_called() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @pytest.mark.django_db | ||||||
| def test_run_llm_query_unsupported_backend(mock_settings): | def test_run_llm_query_unsupported_backend(mock_settings): | ||||||
|     mock_settings.LLM_BACKEND = "unsupported" |     mock_settings.LLM_BACKEND = "unsupported" | ||||||
|  |     client = AIClient() | ||||||
|     with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"): |     with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"): | ||||||
|         run_llm_query("Test prompt") |         client.run_llm_query("Test prompt") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @pytest.mark.django_db | ||||||
| def test_run_openai_query(httpx_mock, mock_settings): | def test_run_openai_query(httpx_mock, mock_settings): | ||||||
|  |     mock_settings.LLM_BACKEND = "openai" | ||||||
|     httpx_mock.add_response( |     httpx_mock.add_response( | ||||||
|         url=f"{mock_settings.OPENAI_URL}/v1/chat/completions", |         url="https://api.openai.com/v1/chat/completions", | ||||||
|         json={ |         json={ | ||||||
|             "choices": [{"message": {"content": "OpenAI response"}}], |             "choices": [{"message": {"content": "OpenAI response"}}], | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     result = _run_openai_query("Test prompt") |     client = AIClient() | ||||||
|  |     result = client.run_llm_query("Test prompt") | ||||||
|     assert result == "OpenAI response" |     assert result == "OpenAI response" | ||||||
| 
 | 
 | ||||||
|     request = httpx_mock.get_request() |     request = httpx_mock.get_request() | ||||||
|     assert request.method == "POST" |     assert request.method == "POST" | ||||||
|     assert request.url == f"{mock_settings.OPENAI_URL}/v1/chat/completions" |  | ||||||
|     assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}" |     assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}" | ||||||
|     assert request.headers["Content-Type"] == "application/json" |     assert request.headers["Content-Type"] == "application/json" | ||||||
|     assert json.loads(request.content) == { |     assert json.loads(request.content) == { | ||||||
| @ -69,18 +74,20 @@ def test_run_openai_query(httpx_mock, mock_settings): | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @pytest.mark.django_db | ||||||
| def test_run_ollama_query(httpx_mock, mock_settings): | def test_run_ollama_query(httpx_mock, mock_settings): | ||||||
|  |     mock_settings.LLM_BACKEND = "ollama" | ||||||
|     httpx_mock.add_response( |     httpx_mock.add_response( | ||||||
|         url=f"{mock_settings.OLLAMA_URL}/api/chat", |         url="http://localhost:11434/api/chat", | ||||||
|         json={"message": {"content": "Ollama response"}}, |         json={"message": {"content": "Ollama response"}}, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     result = _run_ollama_query("Test prompt") |     client = AIClient() | ||||||
|  |     result = client.run_llm_query("Test prompt") | ||||||
|     assert result == "Ollama response" |     assert result == "Ollama response" | ||||||
| 
 | 
 | ||||||
|     request = httpx_mock.get_request() |     request = httpx_mock.get_request() | ||||||
|     assert request.method == "POST" |     assert request.method == "POST" | ||||||
|     assert request.url == f"{mock_settings.OLLAMA_URL}/api/chat" |  | ||||||
|     assert json.loads(request.content) == { |     assert json.loads(request.content) == { | ||||||
|         "model": mock_settings.LLM_MODEL, |         "model": mock_settings.LLM_MODEL, | ||||||
|         "messages": [{"role": "user", "content": "Test prompt"}], |         "messages": [{"role": "user", "content": "Test prompt"}], | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user