mirror of
https://github.com/kovidgoyal/calibre.git
synced 2026-03-31 22:32:28 -04:00
More work on prefs dialog for OpenRouter
This commit is contained in:
parent
51a301c50b
commit
ebea8de58b
18
src/calibre/ai/__init__.py
Normal file
18
src/calibre/ai/__init__.py
Normal file
@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2025, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
from enum import Flag, auto
|
||||
|
||||
|
||||
class AICapabilities(Flag):
|
||||
none = auto()
|
||||
text_to_text = auto()
|
||||
text_to_image = auto()
|
||||
|
||||
@property
|
||||
def supports_text_to_text(self) -> bool:
|
||||
return AICapabilities.text_to_text in self
|
||||
|
||||
@property
|
||||
def supports_text_to_image(self) -> bool:
|
||||
return AICapabilities.text_to_image in self
|
||||
@ -1,26 +1,69 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2025, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
from qt.core import QComboBox, QDialog, QGroupBox, QHBoxLayout, QLabel, QStackedLayout, QVBoxLayout, QWidget
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from calibre.utils.config import JSONConfig
|
||||
from calibre.ai import AICapabilities
|
||||
from calibre.ai.prefs import prefs
|
||||
from calibre.customize.ui import available_ai_provider_plugins
|
||||
from calibre.gui2 import Application, error_dialog
|
||||
|
||||
|
||||
@lru_cache(2)
|
||||
def prefs() -> JSONConfig:
|
||||
ans = JSONConfig('ai')
|
||||
ans.defaults['providers'] = {}
|
||||
return ans
|
||||
class ConfigureAI(QWidget):
|
||||
|
||||
def __init__(self, purpose: AICapabilities = AICapabilities.text_to_text, parent: QWidget | None = None):
|
||||
super().__init__(parent)
|
||||
plugins = tuple(p for p in available_ai_provider_plugins() if p.capabilities & purpose == purpose)
|
||||
self.available_plugins = plugins
|
||||
self.purpose = purpose
|
||||
self.plugin_config_widgets: tuple[QWidget, ...] = tuple(p.config_widget() for p in plugins)
|
||||
v = QVBoxLayout(self)
|
||||
self.gb = QGroupBox(self)
|
||||
self.stack = s = QStackedLayout(self.gb)
|
||||
for pc in self.plugin_config_widgets:
|
||||
pc.setParent(self)
|
||||
s.addWidget(pc)
|
||||
if len(plugins) > 1:
|
||||
self.provider_combo = pcb = QComboBox(self)
|
||||
pcb.addItems([p.name for p in plugins])
|
||||
la = QLabel(_('AI &provider:'))
|
||||
la.setBuddy(pcb)
|
||||
h = QHBoxLayout()
|
||||
h.addWidget(la), h.addWidget(pcb), h.addStretch()
|
||||
v.addLayout(h)
|
||||
pcb.currentIndexChanged.connect(self.stack.setCurrentIndex)
|
||||
idx = pcb.findText(prefs()['purpose_map'].get(str(self.purpose), ''))
|
||||
pcb.setCurrentIndex(max(0, idx))
|
||||
elif len(plugins) == 1:
|
||||
self.gb.setTitle(_('Configure AI provider: {}').format(plugins[0].name))
|
||||
else:
|
||||
self.none_label = la = QLabel(_('No AI providers found that have the capabilities: {}. Make sure you have not'
|
||||
' disabled some AI provider plugins').format(purpose))
|
||||
s.addWidget()
|
||||
v.addWidget(self.gb)
|
||||
|
||||
def commit(self) -> bool:
|
||||
if not self.available_plugins:
|
||||
error_dialog(self, _('No AI providers'), self.none_label.text(), show=True)
|
||||
return False
|
||||
if len(self.available_plugins) == 1:
|
||||
idx = 0
|
||||
else:
|
||||
idx = self.provider_combo.currentIndex()
|
||||
p, w = self.available_plugins[idx], self.plugin_config_widgets[idx]
|
||||
if not w.validate():
|
||||
return False
|
||||
p.save_settings(w)
|
||||
pmap = prefs()['purpose_map']
|
||||
pmap[str(self.purpose)] = p.name
|
||||
prefs().set('purpose_map', pmap)
|
||||
return True
|
||||
|
||||
|
||||
def pref_for_provider(name: str, key: str, defval: Any = None) -> Any:
|
||||
return prefs()['providers'].get(key, defval)
|
||||
|
||||
|
||||
def set_prefs_for_provider(name: str, pref_map: dict[str, Any]) -> None:
|
||||
p = prefs()
|
||||
p['providers'][name] = deepcopy(pref_map)
|
||||
p.set('providers', p['providers'])
|
||||
if __name__ == '__main__':
|
||||
app = Application([])
|
||||
d = QDialog()
|
||||
v = QVBoxLayout(d)
|
||||
w = ConfigureAI(parent=d)
|
||||
v.addWidget(w)
|
||||
d.exec()
|
||||
|
||||
@ -11,3 +11,8 @@ class OpenRouterAI(AIProviderPlugin):
|
||||
description = _('AI services from OpenRouter.ai. Allows choosing from hundreds of different AI models to query.')
|
||||
author = 'Kovid Goyal'
|
||||
builtin_live_module_name = 'calibre.ai.open_router.backend'
|
||||
|
||||
@property
|
||||
def capabilities(self):
|
||||
from calibre.ai import AICapabilities
|
||||
return AICapabilities.text_to_text | AICapabilities.text_to_image
|
||||
|
||||
@ -11,6 +11,7 @@ from threading import Thread
|
||||
from typing import NamedTuple
|
||||
|
||||
from calibre import browser
|
||||
from calibre.ai import AICapabilities
|
||||
from calibre.constants import __version__, cache_dir
|
||||
from calibre.utils.lock import SingleInstance
|
||||
|
||||
@ -106,18 +107,25 @@ class Model(NamedTuple):
|
||||
pricing: Pricing
|
||||
parameters: tuple[str, ...]
|
||||
is_moderated: bool
|
||||
supports_text_to_text: bool
|
||||
capabilities: AICapabilities
|
||||
tokenizer: str
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, x: dict[str, object]) -> 'Model':
|
||||
arch = x['architecture']
|
||||
capabilities = AICapabilities.none
|
||||
if 'text' in arch['input_modalities']:
|
||||
if 'text' in arch['output_modalities']:
|
||||
capabilities |= AICapabilities.text_to_text
|
||||
if 'image' in arch['output_modalities']:
|
||||
capabilities |= AICapabilities.text_to_image
|
||||
|
||||
return Model(
|
||||
name=x['name'], id=x['id'], created=datetime.datetime.fromtimestamp(x['created'], datetime.timezone.utc),
|
||||
description=x['description'], context_length=x['context_length'],
|
||||
parameters=tuple(x['supported_parameters']), pricing=Pricing.from_dict(x['pricing']),
|
||||
is_moderated=x['top_provider']['is_moderated'], tokenizer=arch['tokenizer'],
|
||||
supports_text_to_text='text' in arch['input_modalities'] and 'text' in arch['output_modalities'],
|
||||
capabilities=capabilities,
|
||||
)
|
||||
|
||||
|
||||
@ -129,6 +137,15 @@ def parse_models_list(entries) -> dict[str, Model]:
|
||||
return ans
|
||||
|
||||
|
||||
def config_widget():
|
||||
from calibre.ai.open_router.config import ConfigWidget
|
||||
return ConfigWidget()
|
||||
|
||||
|
||||
def save_settings(config_widget):
|
||||
config_widget.save_settings()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from pprint import pprint
|
||||
for m in get_available_models().values():
|
||||
|
||||
@ -2,10 +2,13 @@
|
||||
# License: GPLv3 Copyright: 2025, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from qt.core import QFormLayout, QHBoxLayout, QLabel, QLineEdit, QPushButton, QWidget
|
||||
from qt.core import QAbstractListModel, QFormLayout, QHBoxLayout, QLabel, QLineEdit, QPushButton, QSortFilterProxyModel, Qt, QWidget, pyqtSignal
|
||||
|
||||
from calibre.ai.config import pref_for_provider
|
||||
from calibre.ai.prefs import pref_for_provider, set_prefs_for_provider
|
||||
from calibre.customize.ui import available_ai_provider_plugins
|
||||
from calibre.gui2 import error_dialog
|
||||
|
||||
from . import OpenRouterAI
|
||||
|
||||
@ -14,24 +17,81 @@ pref = partial(pref_for_provider, OpenRouterAI.name)
|
||||
|
||||
class Model(QWidget):
|
||||
|
||||
select_model = pyqtSignal(str, bool)
|
||||
|
||||
def __init__(self, for_text: bool = True, parent: QWidget | None = None):
|
||||
super().__init__(parent)
|
||||
l = QHBoxLayout(self)
|
||||
l.setContentsMargins(0, 0, 0, 0)
|
||||
self.for_text = for_text
|
||||
self.model_id, self.model_name = pref(
|
||||
'text_model' if for_text else 'text_to_image_model', ('', _('Automatic (free)')))
|
||||
'text_model' if for_text else 'text_to_image_model', ('', _('Automatic (low cost)')))
|
||||
self.la = la = QLabel(self.model_name)
|
||||
self.setToolTip(_('The model to use for text related tasks') if for_text else _(
|
||||
'The model to use for generating iamges from text'))
|
||||
'The model to use for generating images from text'))
|
||||
self.setToolTip(self.toolTip() + '\n\n' + _(
|
||||
'If not specified an appropriate free to use model is chosen automatically.\n'
|
||||
'If no free model is available then cheaper ones are preferred.'))
|
||||
self.b = b = QPushButton(_('&Select'))
|
||||
self.b = b = QPushButton(_('&Change'))
|
||||
b.setToolTip(_('Choose a model'))
|
||||
l.addWidget(la), l.addWidget(b)
|
||||
b.clicked.connect(self.select_model)
|
||||
b.clicked.connect(self._select_model)
|
||||
|
||||
def select_model(self):
|
||||
pass
|
||||
def _select_model(self):
|
||||
self.select_model.emit(self.model_id, self.for_text)
|
||||
|
||||
|
||||
class ModelsModel(QAbstractListModel):
|
||||
|
||||
def __init__(self, parent: QWidget | None = None):
|
||||
super().__init__(parent)
|
||||
for plugin in available_ai_provider_plugins():
|
||||
if plugin.name == OpenRouterAI.name:
|
||||
self.backend = plugin.builtin_live_module
|
||||
break
|
||||
else:
|
||||
raise ValueError('Could not find OpenRouterAI plugin')
|
||||
self.all_models_map = self.backend.get_available_models()
|
||||
self.all_models = sorted(self.all_models_map.values(), key=lambda m: m.created, reverse=True)
|
||||
|
||||
def rowCount(self, parent):
|
||||
return len(self.all_models)
|
||||
|
||||
def data(self, index, role):
|
||||
try:
|
||||
m = self.all_models[index.row()]
|
||||
except IndexError:
|
||||
return None
|
||||
if role == Qt.ItemDataRole.DisplayRole:
|
||||
return m.name
|
||||
if role == Qt.ItemDataRole.UserRole:
|
||||
return m
|
||||
return None
|
||||
|
||||
|
||||
class ProxyModels(QSortFilterProxyModel):
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.source_model = ModelsModel(self)
|
||||
self.setSourceModel(self.source_model)
|
||||
self.filters = []
|
||||
|
||||
def filterAcceptsRow(self, source_row: int, source_parent) -> bool:
|
||||
try:
|
||||
m = self.source_model.all_models[source_row]
|
||||
except IndexError:
|
||||
return False
|
||||
for f in self.filters:
|
||||
if not f(m):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class ChooseModel(QWidget):
|
||||
|
||||
def __init__(self, parent: QWidget | None = None):
|
||||
super().__init__(parent)
|
||||
|
||||
|
||||
class ConfigWidget(QWidget):
|
||||
@ -39,13 +99,45 @@ class ConfigWidget(QWidget):
|
||||
def __init__(self, parent: QWidget | None = None):
|
||||
super().__init__(parent)
|
||||
l = QFormLayout(self)
|
||||
l.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.ExpandingFieldsGrow)
|
||||
la = QLabel('<p>'+_(
|
||||
'You have to create an account at {}, then generate an'
|
||||
' API key and purchase a token amount of credits. After that, you can use any AI'
|
||||
' model you like, including free ones.').format('<a href="https://openrouter.ai">OpenRouter.ai</a>'))
|
||||
'You have to create an account at {0}, then generate an'
|
||||
' API key and purchase a token amount of credits. After that, you can use any '
|
||||
' <a href="{1}">AI model</a> you like, including free ones.'
|
||||
).format('<a href="https://openrouter.ai">OpenRouter.ai</a>', 'https://openrouter.ai/rankings'))
|
||||
la.setWordWrap(True)
|
||||
la.setOpenExternalLinks(True)
|
||||
l.addRow(la)
|
||||
self.api_key_edit = a = QLineEdit(self)
|
||||
a.setPlaceholderText(_('An API key is required to use OpenRouter'))
|
||||
l.addRow(_('API &key:'), a)
|
||||
if key := pref('api_key'):
|
||||
a.setText(key)
|
||||
self.text_model = tm = Model(parent=self)
|
||||
tm.select_model.connect(self.select_model)
|
||||
l.addRow(_('Model for &text tasks:'), tm)
|
||||
self.choose_model = cm = ChooseModel(self)
|
||||
cm.setVisible(False)
|
||||
l.addRow(cm)
|
||||
|
||||
def select_model(self, model_id: str, for_text: bool) -> None:
|
||||
self.model_choice_target = self.sender()
|
||||
|
||||
@property
|
||||
def api_key(self) -> str:
|
||||
return self.api_key_edit.text().strip()
|
||||
|
||||
@property
|
||||
def settings(self) -> dict[str, Any]:
|
||||
return {'api_key': self.api_key}
|
||||
|
||||
def validate(self) -> bool:
|
||||
if self.api_key:
|
||||
return True
|
||||
error_dialog(self, _('No API key'), _(
|
||||
'You must supply an API key to use OpenRouter. Remember to also buy a few credits, even if you'
|
||||
' plan on using only free models.'), show=True)
|
||||
return False
|
||||
|
||||
def save_settings(self):
|
||||
set_prefs_for_provider(OpenRouterAI.name, self.settings)
|
||||
|
||||
26
src/calibre/ai/prefs.py
Normal file
26
src/calibre/ai/prefs.py
Normal file
@ -0,0 +1,26 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2025, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from calibre.utils.config import JSONConfig
|
||||
|
||||
|
||||
@lru_cache(2)
|
||||
def prefs() -> JSONConfig:
|
||||
ans = JSONConfig('ai')
|
||||
ans.defaults['providers'] = {}
|
||||
ans.defaults['purpose_map'] = {}
|
||||
return ans
|
||||
|
||||
|
||||
def pref_for_provider(name: str, key: str, defval: Any = None) -> Any:
|
||||
return prefs()['providers'].get(key, defval)
|
||||
|
||||
|
||||
def set_prefs_for_provider(name: str, pref_map: dict[str, Any]) -> None:
|
||||
p = prefs()
|
||||
p['providers'][name] = deepcopy(pref_map)
|
||||
p.set('providers', p['providers'])
|
||||
@ -831,6 +831,12 @@ class AIProviderPlugin(Plugin): # {{{
|
||||
# Used by builtin AI Provider plugins to live load the backend code
|
||||
builtin_live_module_name = ''
|
||||
|
||||
# See the AICapabilities enum. Sub-classes *must* implement this to the
|
||||
# capabilities they support. Note this is independent of configuration.
|
||||
@property
|
||||
def capabilities(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def builtin_live_module(self):
|
||||
if not self.builtin_live_module_name:
|
||||
@ -845,4 +851,14 @@ class AIProviderPlugin(Plugin): # {{{
|
||||
|
||||
def customization_help(self):
|
||||
return ''
|
||||
|
||||
def config_widget(self):
|
||||
if self.builtin_live_module_name:
|
||||
return self.builtin_live_module.config_widget()
|
||||
raise NotImplementedError()
|
||||
|
||||
def save_settings(self, config_widget):
|
||||
if self.builtin_live_module_name:
|
||||
return self.builtin_live_module.save_settings(config_widget)
|
||||
raise NotImplementedError()
|
||||
# }}}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user