From f84f3ca7bafbb759bd31a744eca9942769801c3c Mon Sep 17 00:00:00 2001
From: lazydao
Date: Tue, 14 Apr 2026 10:24:18 +0800
Subject: [PATCH] add OpenAI-compatible AI provider

---
 src/calibre/ai/openai_compatible/__init__.py |  20 ++
 src/calibre/ai/openai_compatible/backend.py  | 215 +++++++++++++++++++
 src/calibre/ai/openai_compatible/config.py   | 199 +++++++++++++++++
 src/calibre/customize/builtins.py            |   3 +-
 4 files changed, 436 insertions(+), 1 deletion(-)
 create mode 100644 src/calibre/ai/openai_compatible/__init__.py
 create mode 100644 src/calibre/ai/openai_compatible/backend.py
 create mode 100644 src/calibre/ai/openai_compatible/config.py

diff --git a/src/calibre/ai/openai_compatible/__init__.py b/src/calibre/ai/openai_compatible/__init__.py
new file mode 100644
index 0000000000..8bc003fb38
--- /dev/null
+++ b/src/calibre/ai/openai_compatible/__init__.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# License: GPLv3 Copyright: 2026, OpenAI
+
+from calibre.customize import AIProviderPlugin
+
+
+class OpenAICompatible(AIProviderPlugin):
+    name = 'OpenAI compatible'
+    version = (1, 0, 0)
+    description = _(
+        'Generic OpenAI compatible AI services. Use this to connect calibre to self-hosted or third-party services'
+        ' that implement the OpenAI chat completions API.'
+    )
+    author = 'OpenAI'
+    builtin_live_module_name = 'calibre.ai.openai_compatible.backend'
+
+    @property
+    def capabilities(self):
+        from calibre.ai import AICapabilities
+        return AICapabilities.text_to_text
diff --git a/src/calibre/ai/openai_compatible/backend.py b/src/calibre/ai/openai_compatible/backend.py
new file mode 100644
index 0000000000..1beb4c789f
--- /dev/null
+++ b/src/calibre/ai/openai_compatible/backend.py
@@ -0,0 +1,215 @@
+#!/usr/bin/env python
+# License: GPLv3 Copyright: 2026, OpenAI
+
+import json
+import posixpath
+from collections.abc import Iterable, Iterator, Sequence
+from functools import lru_cache
+from typing import Any, NamedTuple
+from urllib.parse import urlparse, urlunparse
+from urllib.request import Request
+
+from calibre.ai import ChatMessage, ChatMessageType, ChatResponse, ResultBlocked, ResultBlockReason
+from calibre.ai.openai_compatible import OpenAICompatible
+from calibre.ai.prefs import decode_secret, pref_for_provider
+from calibre.ai.utils import chat_with_error_handler, develop_text_chat, download_data, read_streaming_response
+
+module_version = 1
+
+
+def pref(key: str, defval: Any = None) -> Any:
+    # Read a setting stored for this provider plugin.
+    return pref_for_provider(OpenAICompatible.name, key, defval)
+
+
+def is_ready_for_use() -> bool:
+    # Usable once the user has configured both an endpoint and a model.
+    return bool(pref('api_url') and pref('text_model'))
+
+
+class Model(NamedTuple):
+    id: str
+    owner: str
+
+    @classmethod
+    def from_dict(cls, x: dict[str, Any]) -> 'Model':
+        return cls(id=x['id'], owner=x.get('owned_by', 'remote'))
+
+
+def api_url(path: str = '', use_api_url: str | None = None) -> str:
+    # Normalize the configured base URL: default the path to /v1 and strip a
+    # trailing /chat/completions or /models so users can paste a full endpoint.
+    base = (pref('api_url') if use_api_url is None else use_api_url) or ''
+    purl = urlparse(base)
+    base_path = (purl.path or '').rstrip('/')
+    if not base_path:
+        base_path = '/v1'
+    elif base_path.endswith('/chat/completions'):
+        base_path = base_path[:-len('/chat/completions')]
+    elif base_path.endswith('/models'):
+        base_path = base_path[:-len('/models')]
+    if path:
+        base_path = posixpath.join(base_path, path)
+    purl = purl._replace(path=base_path)
+    return urlunparse(purl)
+
+
+def raw_api_key(use_api_key: str | None = None) -> str:
+    key = pref('api_key') if use_api_key is None else use_api_key
+    return decode_secret(key) if key else ''
+
+
+@lru_cache(32)
+def request_headers(
+    use_api_key: str | None = None, use_headers: Sequence[tuple[str, str]] | None = None
+) -> tuple[tuple[str, str], ...]:
+    ans = [('Content-Type', 'application/json')]
+    extra_headers = pref('headers', ()) if use_headers is None else use_headers
+    extra_headers = tuple(extra_headers or ())
+    has_auth = False
+    for key, val in extra_headers:
+        if key.lower() == 'authorization':
+            has_auth = True
+        ans.append((key, val))
+    # Only decode the stored secret when no user supplied Authorization header exists.
+    if not has_auth:
+        if api_key := raw_api_key(use_api_key):
+            ans.insert(0, ('Authorization', f'Bearer {api_key}'))
+    return tuple(ans)
+
+
+@lru_cache(8)
+def get_available_models(
+    use_api_url: str | None = None, use_api_key: str | None = None, use_headers: Sequence[tuple[str, str]] | None = None
+) -> dict[str, Model]:
+    url = api_url('models', use_api_url)
+    data = json.loads(download_data(url, request_headers(use_api_key, use_headers)))
+    ans = {}
+    if 'data' in data:
+        for model_data in data['data']:
+            model = Model.from_dict(model_data)
+            ans[model.id] = model
+    return ans
+
+
+def human_readable_model_name(model_id: str) -> str:
+    return model_id
+
+
+def config_widget():
+    from calibre.ai.openai_compatible.config import ConfigWidget
+    return ConfigWidget()
+
+
+def save_settings(config_widget):
+    config_widget.save_settings()
+
+
+def for_assistant(self: ChatMessage) -> dict[str, Any]:
+    # Serialize a ChatMessage into the chat completions message format.
+    if self.type not in (ChatMessageType.assistant, ChatMessageType.system, ChatMessageType.user, ChatMessageType.developer):
+        raise ValueError(f'Unsupported message type: {self.type}')
+    return {'role': self.type.value, 'content': self.query}
+
+
+def coerce_text(value: Any) -> str:
+    # Servers may send content as a string, a content-part dict or a list of parts.
+    if isinstance(value, str):
+        return value
+    if isinstance(value, dict):
+        text = value.get('text')
+        return text if isinstance(text, str) else ''
+    if isinstance(value, list):
+        return ''.join(filter(None, map(coerce_text, value)))
+    return ''
+
+
+def chat_request(
+    data: dict[str, Any], url_override: str | None = None, use_api_key: str | None = None,
+    use_headers: Sequence[tuple[str, str]] | None = None
+) -> Request:
+    url = api_url('chat/completions', url_override)
+    return Request(url, data=json.dumps(data).encode('utf-8'), headers=dict(request_headers(use_api_key, use_headers)), method='POST')
+
+
+def as_chat_responses(d: dict[str, Any], model_id: str) -> Iterator[ChatResponse]:
+    # Translate one streamed chat completions chunk into ChatResponse objects.
+    blocked = False
+    for choice in d.get('choices', ()):
+        delta = choice.get('delta') or {}
+        content = coerce_text(delta.get('content'))
+        reasoning = coerce_text(delta.get('reasoning_content'))
+        role = delta.get('role') or 'assistant'
+        if content or reasoning:
+            yield ChatResponse(
+                content=content, reasoning=reasoning, type=ChatMessageType(role), plugin_name=OpenAICompatible.name
+            )
+        if choice.get('finish_reason') == 'content_filter':
+            blocked = True
+    if blocked:
+        yield ChatResponse(exception=ResultBlocked(ResultBlockReason.safety), plugin_name=OpenAICompatible.name)
+        return
+    if d.get('usage'):
+        yield ChatResponse(
+            has_metadata=True, provider=OpenAICompatible.name, model=d.get('model') or model_id,
+            plugin_name=OpenAICompatible.name
+        )
+
+
+def text_chat_implementation(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
+    model_id = use_model or pref('text_model')
+    temperature = pref('temperature', 0.7)
+    data = {
+        'model': model_id,
+        'messages': [for_assistant(m) for m in messages],
+        'stream': True,
+        'temperature': temperature,
+    }
+    rq = chat_request(data)
+    seen_metadata = False
+    for event in read_streaming_response(rq, OpenAICompatible.name, timeout=pref('timeout', 120)):
+        for response in as_chat_responses(event, model_id):
+            if response.has_metadata:
+                seen_metadata = True
+            yield response
+            if response.exception:
+                return
+    if not seen_metadata:
+        yield ChatResponse(has_metadata=True, provider=OpenAICompatible.name, model=model_id, plugin_name=OpenAICompatible.name)
+
+
+def text_chat(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
+    yield from chat_with_error_handler(text_chat_implementation(messages, use_model))
+
+
+def develop(use_model: str = '', msg: str = '') -> None:
+    m = (ChatMessage(msg),) if msg else ()
+    develop_text_chat(text_chat, use_model, messages=m)
+
+
+def find_tests():
+    import unittest
+
+    class TestOpenAICompatibleBackend(unittest.TestCase):
+
+        def test_api_url_normalization(self):
+            self.assertEqual(api_url('models', 'http://localhost:1234'), 'http://localhost:1234/v1/models')
+            self.assertEqual(api_url('models', 'http://localhost:1234/v1'), 'http://localhost:1234/v1/models')
+            self.assertEqual(api_url('models', 'https://example.com/custom/api'), 'https://example.com/custom/api/models')
+            self.assertEqual(api_url('chat/completions', 'https://ark.cn-beijing.volces.com/api/v3'), 'https://ark.cn-beijing.volces.com/api/v3/chat/completions')
+            self.assertEqual(api_url('chat/completions', 'https://ark.cn-beijing.volces.com/api/v3/chat/completions'), 'https://ark.cn-beijing.volces.com/api/v3/chat/completions')
+
+        def test_request_headers_allows_missing_headers_pref(self):
+            headers = request_headers()
+            self.assertEqual(headers, (('Content-Type', 'application/json'),))
+
+        def test_parsing_stream_deltas(self):
+            responses = tuple(as_chat_responses({
+                'model': 'demo-model',
+                'choices': [
+                    {'delta': {'role': 'assistant', 'content': 'Hello', 'reasoning_content': 'Think'}, 'finish_reason': None}
+                ],
+                'usage': {'total_tokens': 42},
+            }, 'demo-model'))
+            self.assertEqual(responses[0].content, 'Hello')
+            self.assertEqual(responses[0].reasoning, 'Think')
+            self.assertTrue(responses[-1].has_metadata)
+            self.assertEqual(responses[-1].model, 'demo-model')
+
+    return unittest.defaultTestLoader.loadTestsFromTestCase(TestOpenAICompatibleBackend)
+
+
+if __name__ == '__main__':
+    develop()
diff --git a/src/calibre/ai/openai_compatible/config.py b/src/calibre/ai/openai_compatible/config.py
new file mode 100644
index 0000000000..05abb0569f
--- /dev/null
+++ b/src/calibre/ai/openai_compatible/config.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python
+# License: GPLv3 Copyright: 2026, OpenAI
+
+from functools import partial
+from typing import Any
+
+from qt.core import (
+    QComboBox,
+    QCompleter,
+    QDoubleSpinBox,
+    QFormLayout,
+    QHBoxLayout,
+    QLabel,
+    QLineEdit,
+    QListView,
+    QPlainTextEdit,
+    QPushButton,
+    QSpinBox,
+    Qt,
+    QWidget,
+)
+
+from calibre.ai.openai_compatible import OpenAICompatible
+from calibre.ai.prefs import decode_secret, encode_secret, pref_for_provider, set_prefs_for_provider
+from calibre.ai.utils import configure, plugin_for_name
+from calibre.gui2 import error_dialog
+from calibre.gui2.widgets import BusyCursor
+
+pref = partial(pref_for_provider, OpenAICompatible.name)
+
+
+class ConfigWidget(QWidget):
+
+    def __init__(self, parent: QWidget | None = None):
+        super().__init__(parent)
+        l = QFormLayout(self)
+        l.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.ExpandingFieldsGrow)
+
+        la = QLabel(_(
+            'Connect calibre to any self-hosted or third-party service that implements the OpenAI compatible'
+            ' /v1/chat/completions API. This is useful for gateways, local servers and other'
+            ' services that are not listed as dedicated providers.'
+        ))
+        la.setWordWrap(True)
+        l.addRow(la)
+
+        self.api_url_edit = a = QLineEdit(self)
+        a.setClearButtonEnabled(True)
+        a.setPlaceholderText(_('For example: {}').format('https://example.com/v1'))
+        l.addRow(_('API &URL:'), a)
+        a.setText(pref('api_url') or '')
+
+        self.api_key_edit = ak = QLineEdit(self)
+        ak.setClearButtonEnabled(True)
+        ak.setPlaceholderText(_('Optional. Sent as the Authorization: Bearer header'))
+        l.addRow(_('API &key:'), ak)
+        if key := pref('api_key'):
+            ak.setText(decode_secret(key))
+
+        self.headers_edit = he = QPlainTextEdit(self)
+        he.setPlaceholderText(_('Optional HTTP headers, one per line, in the format: Header-Name: Value'))
+        l.addRow(_('HTTP &headers:'), he)
+        he.setPlainText('\n'.join(f'{k}: {v}' for (k, v) in pref('headers') or ()))
+
+        self.timeout_sb = t = QSpinBox(self)
+        t.setRange(15, 600)
+        t.setSingleStep(1)
+        t.setSuffix(_(' seconds'))
+        t.setValue(pref('timeout', 120))
+        l.addRow(_('&Timeout:'), t)
+
+        self.temp_sb = temp = QDoubleSpinBox(self)
+        temp.setRange(0.0, 2.0)
+        temp.setSingleStep(0.1)
+        temp.setValue(pref('temperature', 0.7))
+        temp.setToolTip(_('Controls randomness. Lower values are more deterministic.'))
+        l.addRow(_('T&emperature:'), temp)
+
+        w = QWidget(self)
+        h = QHBoxLayout(w)
+        h.setContentsMargins(0, 0, 0, 0)
+
+        self.model_combo = mc = QComboBox(w)
+        mc.setEditable(True)
+        mc.setInsertPolicy(QComboBox.InsertPolicy.NoInsert)
+        mc.setView(QListView(mc))
+        mc.setSizeAdjustPolicy(QComboBox.SizeAdjustPolicy.AdjustToContentsOnFirstShow)
+        completer = QCompleter(mc)
+        completer.setCaseSensitivity(Qt.CaseSensitivity.CaseInsensitive)
+        mc.setCompleter(completer)
+
+        if saved_model := pref('text_model') or '':
+            mc.addItem(saved_model)
+            mc.setCurrentText(saved_model)
+
+        self.refresh_btn = rb = QPushButton(_('&Refresh models'), w)
+        rb.clicked.connect(self.refresh_models)
+        h.addWidget(mc, stretch=10)
+        h.addWidget(rb)
+        l.addRow(_('Model for &text tasks:'), w)
+
+        self.model_status = ms = QLabel('', self)
+        ms.setWordWrap(True)
+        ms.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse)
+        l.addRow('', ms)
+
+    def refresh_models(self):
+        with BusyCursor():
+            try:
+                plugin = plugin_for_name(OpenAICompatible.name)
+                backend = plugin.builtin_live_module
+                backend.get_available_models.cache_clear()
+                encoded_key = encode_secret(self.api_key) if self.api_key else ''
+                models_dict = backend.get_available_models(self.api_url, encoded_key, self.headers)
+                current_text = self.text_model
+                model_ids = sorted(models_dict, key=lambda x: x.lower())
+                self.model_combo.blockSignals(True)
+                self.model_combo.clear()
+                for model_id in model_ids:
+                    self.model_combo.addItem(model_id)
+                self.model_combo.setCurrentText(current_text or (model_ids[0] if model_ids else ''))
+                self.model_combo.blockSignals(False)
+                if model_ids:
+                    sample = ', '.join(model_ids[:3])
+                    msg = _('Found {} models. e.g.: {}').format(len(model_ids), sample)
+                    if len(model_ids) > 3:
+                        msg += _(' (and more)')
+                    self.model_status.setText(msg)
+                    self.model_status.setToolTip('\n'.join(model_ids))
+                else:
+                    self.model_status.setText(_('The server responded, but returned no models.'))
+                    self.model_status.setToolTip('')
+            except Exception as e:
+                self.model_status.setText(_('Failed to refresh models: {}').format(e))
+                self.model_status.setToolTip('')
+
+    @property
+    def api_url(self) -> str:
+        return self.api_url_edit.text().strip()
+
+    @property
+    def api_key(self) -> str:
+        return self.api_key_edit.text().strip()
+
+    @property
+    def text_model(self) -> str:
+        return self.model_combo.currentText().strip()
+
+    @property
+    def timeout(self) -> int:
+        return self.timeout_sb.value()
+
+    @property
+    def temperature(self) -> float:
+        return self.temp_sb.value()
+
+    @property
+    def headers(self) -> tuple[tuple[str, str], ...]:
+        ans = []
+        for line in self.headers_edit.toPlainText().splitlines():
+            if line := line.strip():
+                key, sep, val = line.partition(':')
+                key, val = key.strip(), val.strip()
+                if key and sep and val:
+                    ans.append((key, val))
+        return tuple(ans)
+
+    @property
+    def settings(self) -> dict[str, Any]:
+        ans = {
+            'api_url': self.api_url,
+            'api_key': encode_secret(self.api_key),
+            'text_model': self.text_model,
+            'timeout': self.timeout,
+            'temperature': self.temperature,
+        }
+        if self.headers:
+            ans['headers'] = self.headers
+        return ans
+
+    @property
+    def is_ready_for_use(self) -> bool:
+        return bool(self.api_url and self.text_model)
+
+    def validate(self) -> bool:
+        if not self.api_url:
+            error_dialog(self, _('No API URL'), _('You must specify the URL of the OpenAI compatible API endpoint.'), show=True)
+            return False
+        if not self.text_model:
+            error_dialog(self, _('No model specified'), _('You must specify a model ID to use for text based tasks.'), show=True)
+            return False
+        return True
+
+    def save_settings(self):
+        set_prefs_for_provider(OpenAICompatible.name, self.settings)
+
+
+if __name__ == '__main__':
+    configure(OpenAICompatible.name)
diff --git a/src/calibre/customize/builtins.py b/src/calibre/customize/builtins.py
index beacada7c9..e7904072b3 100644
--- a/src/calibre/customize/builtins.py
+++ b/src/calibre/customize/builtins.py
@@ -8,6 +8,7 @@ from calibre.ai.github import GitHubAI
 from calibre.ai.google import GoogleAI
 from calibre.ai.lm_studio import LMStudioAI
 from calibre.ai.ollama import OllamaAI
+from calibre.ai.openai_compatible import OpenAICompatible
 from calibre.ai.open_router import OpenRouterAI
 from calibre.constants import numeric_version
 from calibre.customize import FileTypePlugin, InterfaceActionBase, MetadataReaderPlugin, MetadataWriterPlugin, PreferencesPlugin, StoreBase
@@ -1986,7 +1987,7 @@ plugins += [
 # }}}
 
 
-plugins.extend((OpenRouterAI, GoogleAI, GitHubAI, OllamaAI, LMStudioAI))
+plugins.extend((OpenRouterAI, GoogleAI, GitHubAI, OllamaAI, LMStudioAI, OpenAICompatible))
 
 if __name__ == '__main__':
     # Test load speed