mirror of
https://github.com/kovidgoyal/calibre.git
synced 2026-04-18 23:08:48 -04:00
add OpenAI-compatible AI provider
This commit is contained in:
parent
4c1041b23f
commit
f84f3ca7ba
20
src/calibre/ai/openai_compatible/__init__.py
Normal file
20
src/calibre/ai/openai_compatible/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2026, OpenAI
|
||||
|
||||
from calibre.customize import AIProviderPlugin
|
||||
|
||||
|
||||
class OpenAICompatible(AIProviderPlugin):
    # Plugin identifier; also the key under which this provider's
    # preferences are stored (see pref_for_provider use in backend.py).
    name = 'OpenAI compatible'
    version = (1, 0, 0)
    description = _(
        'Generic OpenAI compatible AI services. Use this to connect calibre to self-hosted or third-party services'
        ' that implement the OpenAI chat completions API.'
    )
    author = 'OpenAI'
    # Dotted path to the module that implements the provider's runtime API
    # (text_chat, config_widget, save_settings, ...); loaded on demand.
    builtin_live_module_name = 'calibre.ai.openai_compatible.backend'

    @property
    def capabilities(self):
        # This generic provider only supports plain text chat; the import is
        # deferred so merely registering the plugin stays cheap.
        from calibre.ai import AICapabilities
        return AICapabilities.text_to_text
215
src/calibre/ai/openai_compatible/backend.py
Normal file
215
src/calibre/ai/openai_compatible/backend.py
Normal file
@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2026, OpenAI
|
||||
|
||||
import json
|
||||
import posixpath
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
from functools import lru_cache
|
||||
from typing import Any, NamedTuple
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from urllib.request import Request
|
||||
|
||||
from calibre.ai import ChatMessage, ChatMessageType, ChatResponse, ResultBlocked, ResultBlockReason
|
||||
from calibre.ai.openai_compatible import OpenAICompatible
|
||||
from calibre.ai.prefs import decode_secret, pref_for_provider
|
||||
from calibre.ai.utils import chat_with_error_handler, develop_text_chat, download_data, read_streaming_response
|
||||
|
||||
module_version = 1  # version of this live-module's API surface (calibre AI backend convention — confirm against loader)
|
||||
|
||||
|
||||
def pref(key: str, defval: Any = None) -> Any:
    '''Look up a setting for this provider in calibre's AI preferences.'''
    provider = OpenAICompatible.name
    return pref_for_provider(provider, key, defval)
|
||||
|
||||
|
||||
def is_ready_for_use() -> bool:
    '''True only when both an API URL and a text model have been configured.'''
    has_url = bool(pref('api_url'))
    has_model = bool(pref('text_model'))
    return has_url and has_model
|
||||
|
||||
|
||||
class Model(NamedTuple):
    '''Minimal description of one model reported by the /models endpoint.'''
    id: str
    owner: str

    @classmethod
    def from_dict(cls, x: dict[str, Any]) -> 'Model':
        '''Build a Model from a single entry of the API's model listing.'''
        owner = x.get('owned_by', 'remote')
        return cls(id=x['id'], owner=owner)
|
||||
|
||||
|
||||
def api_url(path: str = '', use_api_url: str | None = None) -> str:
|
||||
base = (pref('api_url') if use_api_url is None else use_api_url) or ''
|
||||
purl = urlparse(base)
|
||||
base_path = (purl.path or '').rstrip('/')
|
||||
if not base_path:
|
||||
base_path = '/v1'
|
||||
elif base_path.endswith('/chat/completions'):
|
||||
base_path = base_path[:-len('/chat/completions')]
|
||||
elif base_path.endswith('/models'):
|
||||
base_path = base_path[:-len('/models')]
|
||||
if path:
|
||||
base_path = posixpath.join(base_path, path)
|
||||
purl = purl._replace(path=base_path)
|
||||
return urlunparse(purl)
|
||||
|
||||
|
||||
def raw_api_key(use_api_key: str | None = None) -> str:
|
||||
key = pref('api_key') if use_api_key is None else use_api_key
|
||||
return decode_secret(key) if key else ''
|
||||
|
||||
|
||||
@lru_cache(32)
def request_headers(
    use_api_key: str | None = None, use_headers: Sequence[tuple[str, str]] | None = None
) -> tuple[tuple[str, str], ...]:
    '''Assemble the HTTP headers for an API request.

    Configured extra headers are appended after Content-Type. A Bearer
    Authorization header built from the configured API key is prepended,
    unless the extra headers already supply their own Authorization.
    '''
    headers = [('Content-Type', 'application/json')]
    configured = pref('headers', ()) if use_headers is None else use_headers
    have_authorization = False
    for name, value in tuple(configured or ()):
        if name.lower() == 'authorization':
            have_authorization = True
        headers.append((name, value))
    api_key = raw_api_key(use_api_key)
    if api_key and not have_authorization:
        headers.insert(0, ('Authorization', f'Bearer {api_key}'))
    return tuple(headers)
|
||||
|
||||
|
||||
@lru_cache(8)
def get_available_models(
    use_api_url: str | None = None, use_api_key: str | None = None, use_headers: Sequence[tuple[str, str]] | None = None
) -> dict[str, Model]:
    '''Fetch and parse the /models listing, returning models keyed by id.'''
    url = api_url('models', use_api_url)
    raw = download_data(url, request_headers(use_api_key, use_headers))
    parsed = json.loads(raw)
    entries = parsed.get('data', ())
    return {m.id: m for m in map(Model.from_dict, entries)}
|
||||
|
||||
|
||||
def human_readable_model_name(model_id: str) -> str:
    '''Generic services expose no display metadata, so the id is the name.'''
    return model_id
|
||||
|
||||
|
||||
def config_widget():
    '''Instantiate the Qt configuration widget for this provider.'''
    # Imported lazily so the backend can be used without a GUI.
    from calibre.ai.openai_compatible.config import ConfigWidget
    widget = ConfigWidget()
    return widget
|
||||
|
||||
|
||||
def save_settings(config_widget):
    '''Delegate persistence to the configuration widget itself.'''
    config_widget.save_settings()
|
||||
|
||||
|
||||
def for_assistant(self: ChatMessage) -> dict[str, Any]:
    '''Serialize a ChatMessage into the OpenAI chat API message format.

    Raises ValueError for message types the chat completions API cannot carry.
    '''
    supported = (ChatMessageType.assistant, ChatMessageType.system, ChatMessageType.user, ChatMessageType.developer)
    if self.type not in supported:
        raise ValueError(f'Unsupported message type: {self.type}')
    return {'role': self.type.value, 'content': self.query}
|
||||
|
||||
|
||||
def coerce_text(value: Any) -> str:
    '''Flatten an API content payload to a single string.

    Accepts a plain string, a {'text': ...} part dict, or an arbitrarily
    nested list of those; anything else contributes the empty string.
    '''
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        candidate = value.get('text')
        if isinstance(candidate, str):
            return candidate
        return ''
    if isinstance(value, list):
        pieces = (coerce_text(item) for item in value)
        return ''.join(piece for piece in pieces if piece)
    return ''
|
||||
|
||||
|
||||
def chat_request(
    data: dict[str, Any], url_override: str | None = None, use_api_key: str | None = None,
    use_headers: Sequence[tuple[str, str]] | None = None
) -> Request:
    '''Build a POST Request for the chat/completions endpoint.

    The payload is JSON-encoded and the standard auth/content headers are
    attached (see request_headers()).
    '''
    endpoint = api_url('chat/completions', url_override)
    body = json.dumps(data).encode('utf-8')
    headers = dict(request_headers(use_api_key, use_headers))
    return Request(endpoint, data=body, headers=headers, method='POST')
|
||||
|
||||
|
||||
def as_chat_responses(d: dict[str, Any], model_id: str) -> Iterator[ChatResponse]:
    '''Translate one decoded streaming chunk into ChatResponse objects.

    Yields a response per choice that carries content or reasoning text.
    If any choice was stopped by the service's content filter, a blocked
    response is yielded instead of metadata; otherwise a metadata response
    follows whenever the chunk reports usage information.
    '''
    content_filtered = False
    for choice in d.get('choices', ()):
        delta = choice.get('delta') or {}
        text = coerce_text(delta.get('content'))
        thinking = coerce_text(delta.get('reasoning_content'))
        if text or thinking:
            role = delta.get('role') or 'assistant'
            yield ChatResponse(
                content=text, reasoning=thinking, type=ChatMessageType(role), plugin_name=OpenAICompatible.name
            )
        if choice.get('finish_reason') == 'content_filter':
            content_filtered = True
    if content_filtered:
        yield ChatResponse(exception=ResultBlocked(ResultBlockReason.safety), plugin_name=OpenAICompatible.name)
    elif d.get('usage'):
        # The chunk's own model name wins over the requested model id.
        yield ChatResponse(
            has_metadata=True, provider=OpenAICompatible.name, model=d.get('model') or model_id,
            plugin_name=OpenAICompatible.name
        )
|
||||
|
||||
|
||||
def text_chat_implementation(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
    '''Stream a chat completion for *messages* as ChatResponse objects.

    A metadata response is always produced: either relayed from the server's
    usage report, or synthesized once the stream ends without one. Streaming
    stops early if a response carries an exception.
    '''
    model_id = use_model or pref('text_model')
    temperature = pref('temperature', 0.7)
    payload = {
        'model': model_id,
        'messages': [for_assistant(m) for m in messages],
        'stream': True,
        'temperature': temperature,
    }
    request = chat_request(payload)
    metadata_sent = False
    for chunk in read_streaming_response(request, OpenAICompatible.name, timeout=pref('timeout', 120)):
        for response in as_chat_responses(chunk, model_id):
            if response.has_metadata:
                metadata_sent = True
            yield response
            if response.exception:
                return
    if not metadata_sent:
        yield ChatResponse(has_metadata=True, provider=OpenAICompatible.name, model=model_id, plugin_name=OpenAICompatible.name)
|
||||
|
||||
|
||||
def text_chat(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
    '''Public chat entry point: the implementation wrapped in error handling.'''
    stream = text_chat_implementation(messages, use_model)
    yield from chat_with_error_handler(stream)
|
||||
|
||||
|
||||
def develop(use_model: str = '', msg: str = '') -> None:
    '''Interactive command-line driver used during development.'''
    messages = (ChatMessage(msg),) if msg else ()
    develop_text_chat(text_chat, use_model, messages=messages)
|
||||
|
||||
|
||||
def find_tests():
    '''Return this backend's unit tests for calibre's test framework.'''
    import unittest

    class TestOpenAICompatibleBackend(unittest.TestCase):

        def test_api_url_normalization(self):
            # Bare hosts and /v1 bases resolve to the same endpoint.
            for base, expected in (
                ('http://localhost:1234', 'http://localhost:1234/v1/models'),
                ('http://localhost:1234/v1', 'http://localhost:1234/v1/models'),
                ('https://example.com/custom/api', 'https://example.com/custom/api/models'),
            ):
                self.assertEqual(api_url('models', base), expected)
            # A base that already names the endpoint is not doubled up.
            ark = 'https://ark.cn-beijing.volces.com/api/v3'
            self.assertEqual(api_url('chat/completions', ark), ark + '/chat/completions')
            self.assertEqual(api_url('chat/completions', ark + '/chat/completions'), ark + '/chat/completions')

        def test_request_headers_allows_missing_headers_pref(self):
            self.assertEqual(request_headers(), (('Content-Type', 'application/json'),))

        def test_parsing_stream_deltas(self):
            chunk = {
                'model': 'demo-model',
                'choices': [
                    {'delta': {'role': 'assistant', 'content': 'Hello', 'reasoning_content': 'Think'}, 'finish_reason': None}
                ],
                'usage': {'total_tokens': 42},
            }
            responses = tuple(as_chat_responses(chunk, 'demo-model'))
            self.assertEqual(responses[0].content, 'Hello')
            self.assertEqual(responses[0].reasoning, 'Think')
            self.assertTrue(responses[-1].has_metadata)
            self.assertEqual(responses[-1].model, 'demo-model')

    return unittest.defaultTestLoader.loadTestsFromTestCase(TestOpenAICompatibleBackend)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
develop()
|
||||
199
src/calibre/ai/openai_compatible/config.py
Normal file
199
src/calibre/ai/openai_compatible/config.py
Normal file
@ -0,0 +1,199 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2026, OpenAI
|
||||
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from qt.core import (
|
||||
QComboBox,
|
||||
QCompleter,
|
||||
QDoubleSpinBox,
|
||||
QFormLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QLineEdit,
|
||||
QListView,
|
||||
QPlainTextEdit,
|
||||
QPushButton,
|
||||
QSpinBox,
|
||||
Qt,
|
||||
QWidget,
|
||||
)
|
||||
|
||||
from calibre.ai.openai_compatible import OpenAICompatible
|
||||
from calibre.ai.prefs import decode_secret, encode_secret, pref_for_provider, set_prefs_for_provider
|
||||
from calibre.ai.utils import configure, plugin_for_name
|
||||
from calibre.gui2 import error_dialog
|
||||
from calibre.gui2.widgets import BusyCursor
|
||||
|
||||
pref = partial(pref_for_provider, OpenAICompatible.name)
|
||||
|
||||
|
||||
class ConfigWidget(QWidget):
    '''Configuration UI for the generic OpenAI-compatible AI provider.

    Collects the API endpoint URL, optional API key and extra HTTP headers,
    request timeout, sampling temperature and the model id to use for text
    tasks. Values are persisted via save_settings() in the shape read back
    by the backend module's pref() helper.

    Fixes over the initial version: refresh_models() now restores the model
    combo's signals with try/finally (previously an exception while
    repopulating left signals blocked forever), and settings only encodes
    the API key when one was actually entered.
    '''

    def __init__(self, parent: QWidget | None = None):
        super().__init__(parent)
        l = QFormLayout(self)
        l.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.ExpandingFieldsGrow)

        # Introductory blurb explaining what this provider is for.
        la = QLabel('<p>' + _(
            'Connect calibre to any self-hosted or third-party service that implements the OpenAI compatible'
            ' <code>/v1/chat/completions</code> API. This is useful for gateways, local servers and other'
            ' services that are not listed as dedicated providers.'
        ))
        la.setWordWrap(True)
        l.addRow(la)

        # Base URL of the service; the backend normalises it (adds /v1, strips
        # trailing /models or /chat/completions).
        self.api_url_edit = a = QLineEdit(self)
        a.setClearButtonEnabled(True)
        a.setPlaceholderText(_('For example: {}').format('https://example.com/v1'))
        l.addRow(_('API &URL:'), a)
        a.setText(pref('api_url') or '')

        # Optional API key; stored obfuscated via encode_secret().
        self.api_key_edit = ak = QLineEdit(self)
        ak.setClearButtonEnabled(True)
        ak.setPlaceholderText(_('Optional. Sent as Authorization: Bearer <key>'))
        l.addRow(_('API &key:'), ak)
        if key := pref('api_key'):
            ak.setText(decode_secret(key))

        # Extra HTTP headers, entered one "Name: Value" pair per line.
        self.headers_edit = he = QPlainTextEdit(self)
        he.setPlaceholderText(_('Optional HTTP headers, one per line, in the format: Header-Name: Value'))
        l.addRow(_('HTTP &headers:'), he)
        he.setPlainText('\n'.join(f'{k}: {v}' for (k, v) in pref('headers') or ()))

        # Network timeout used by the backend when streaming responses.
        self.timeout_sb = t = QSpinBox(self)
        t.setRange(15, 600)
        t.setSingleStep(1)
        t.setSuffix(_(' seconds'))
        t.setValue(pref('timeout', 120))
        l.addRow(_('&Timeout:'), t)

        # Sampling temperature passed through to the chat completions request.
        self.temp_sb = temp = QDoubleSpinBox(self)
        temp.setRange(0.0, 2.0)
        temp.setSingleStep(0.1)
        temp.setValue(pref('temperature', 0.7))
        temp.setToolTip(_('Controls randomness. Lower values are more deterministic.'))
        l.addRow(_('T&emperature:'), temp)

        # Row combining the editable model combo with a refresh button.
        w = QWidget(self)
        h = QHBoxLayout(w)
        h.setContentsMargins(0, 0, 0, 0)

        # NOTE(review): the unscoped enum forms below (QComboBox.NoInsert,
        # QComboBox.AdjustToContentsOnFirstShow, Qt.CaseInsensitive) differ
        # from the scoped forms used elsewhere in this class; presumably
        # calibre's qt.core wrapper accepts both under PyQt6 — confirm.
        self.model_combo = mc = QComboBox(w)
        mc.setEditable(True)
        mc.setInsertPolicy(QComboBox.NoInsert)
        mc.setView(QListView(mc))
        mc.setSizeAdjustPolicy(QComboBox.AdjustToContentsOnFirstShow)
        completer = QCompleter(mc)
        completer.setCaseSensitivity(Qt.CaseInsensitive)
        mc.setCompleter(completer)

        # Pre-seed the combo with the saved model so the widget is usable
        # without contacting the server first.
        if saved_model := pref('text_model') or '':
            mc.addItem(saved_model)
            mc.setCurrentText(saved_model)

        self.refresh_btn = rb = QPushButton(_('&Refresh models'), w)
        rb.clicked.connect(self.refresh_models)
        h.addWidget(mc, stretch=10)
        h.addWidget(rb)
        l.addRow(_('Model for &text tasks:'), w)

        # Status line reporting the outcome of the last model refresh.
        self.model_status = ms = QLabel('', self)
        ms.setWordWrap(True)
        ms.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse)
        l.addRow('', ms)

    def refresh_models(self):
        '''Query the configured server for its model list and repopulate the combo.

        Runs synchronously under a busy cursor; any failure is reported in
        the status label rather than raised.
        '''
        with BusyCursor():
            try:
                plugin = plugin_for_name(OpenAICompatible.name)
                backend = plugin.builtin_live_module
                # Drop any cached listing so the current URL/key/headers are used.
                backend.get_available_models.cache_clear()
                encoded_key = encode_secret(self.api_key) if self.api_key else ''
                models_dict = backend.get_available_models(self.api_url, encoded_key, self.headers)
                current_text = self.text_model
                model_ids = sorted(models_dict, key=lambda x: x.lower())
                # Block signals while repopulating, and restore them even if
                # Qt raises part-way through (previously signals stayed
                # blocked forever on error).
                self.model_combo.blockSignals(True)
                try:
                    self.model_combo.clear()
                    for model_id in model_ids:
                        self.model_combo.addItem(model_id)
                    self.model_combo.setCurrentText(current_text or (model_ids[0] if model_ids else ''))
                finally:
                    self.model_combo.blockSignals(False)
                if model_ids:
                    sample = ', '.join(model_ids[:3])
                    msg = _('Found {} models. e.g.: {}').format(len(model_ids), sample)
                    if len(model_ids) > 3:
                        msg += _(' (and more)')
                    self.model_status.setText(msg)
                    self.model_status.setToolTip('\n'.join(model_ids))
                else:
                    self.model_status.setText(_('The server responded, but returned no models.'))
                    self.model_status.setToolTip('')
            except Exception as e:
                self.model_status.setText(_('Failed to refresh models: {}').format(e))
                self.model_status.setToolTip('')

    @property
    def api_url(self) -> str:
        '''The entered base URL, stripped of surrounding whitespace.'''
        return self.api_url_edit.text().strip()

    @property
    def api_key(self) -> str:
        '''The entered API key in plain text ('' when unset).'''
        return self.api_key_edit.text().strip()

    @property
    def text_model(self) -> str:
        '''The model id currently selected or typed into the combo.'''
        return self.model_combo.currentText().strip()

    @property
    def timeout(self) -> int:
        '''Request timeout in seconds.'''
        return self.timeout_sb.value()

    @property
    def temperature(self) -> float:
        '''Sampling temperature for chat requests.'''
        return self.temp_sb.value()

    @property
    def headers(self) -> tuple[tuple[str, str], ...]:
        '''Parsed extra headers: one (name, value) per well-formed "Name: Value" line.'''
        ans = []
        for line in self.headers_edit.toPlainText().splitlines():
            if line := line.strip():
                key, sep, val = line.partition(':')
                key, val = key.strip(), val.strip()
                # Silently skip malformed lines (missing colon, name or value).
                if key and sep and val:
                    ans.append((key, val))
        return tuple(ans)

    @property
    def settings(self) -> dict[str, Any]:
        '''All configured values, in the shape consumed by the backend's pref().'''
        ans = {
            'api_url': self.api_url,
            # Encode only when a key was entered, mirroring refresh_models(),
            # so an unset key stays falsy for is_ready_for_use()/raw_api_key().
            'api_key': encode_secret(self.api_key) if self.api_key else '',
            'text_model': self.text_model,
            'timeout': self.timeout,
            'temperature': self.temperature,
        }
        if self.headers:
            ans['headers'] = self.headers
        return ans

    @property
    def is_ready_for_use(self) -> bool:
        '''Mirrors the backend check: URL and model must both be set.'''
        return bool(self.api_url and self.text_model)

    def validate(self) -> bool:
        '''Show an error dialog and return False if required fields are missing.'''
        if not self.api_url:
            error_dialog(self, _('No API URL'), _('You must specify the URL of the OpenAI compatible API endpoint.'), show=True)
            return False
        if not self.text_model:
            error_dialog(self, _('No model specified'), _('You must specify a model ID to use for text based tasks.'), show=True)
            return False
        return True

    def save_settings(self):
        '''Persist the current settings for this provider.'''
        set_prefs_for_provider(OpenAICompatible.name, self.settings)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
configure(OpenAICompatible.name)
|
||||
@ -8,6 +8,7 @@ from calibre.ai.github import GitHubAI
|
||||
from calibre.ai.google import GoogleAI
|
||||
from calibre.ai.lm_studio import LMStudioAI
|
||||
from calibre.ai.ollama import OllamaAI
|
||||
from calibre.ai.openai_compatible import OpenAICompatible
|
||||
from calibre.ai.open_router import OpenRouterAI
|
||||
from calibre.constants import numeric_version
|
||||
from calibre.customize import FileTypePlugin, InterfaceActionBase, MetadataReaderPlugin, MetadataWriterPlugin, PreferencesPlugin, StoreBase
|
||||
@ -1986,7 +1987,7 @@ plugins += [
|
||||
|
||||
# }}}
|
||||
|
||||
plugins.extend((OpenRouterAI, GoogleAI, GitHubAI, OllamaAI, LMStudioAI))
|
||||
plugins.extend((OpenRouterAI, GoogleAI, GitHubAI, OllamaAI, LMStudioAI, OpenAICompatible))
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Test load speed
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user