Start work on Ollama backend
This commit is contained in:
parent 2be1c4d276
commit f9914da65f
20
src/calibre/ai/ollama/__init__.py
Normal file
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2025, Kovid Goyal <kovid at kovidgoyal.net>

from calibre.customize import AIProviderPlugin


class OllamaAI(AIProviderPlugin):
    DEFAULT_URL = 'http://localhost:11434'
    name = 'OllamaAI'
    version = (1, 0, 0)
    description = _('AI services from Ollama, when you want to run AI models yourself rather than rely on a third party provider.')
    author = 'Kovid Goyal'
    builtin_live_module_name = 'calibre.ai.ollama.backend'

    @property
    def capabilities(self):
        from calibre.ai import AICapabilities
        return (
            AICapabilities.text_to_text
        )
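The plugin above only declares metadata and points calibre at a live backend module; all actual traffic goes to the Ollama server at DEFAULT_URL. Before wiring it into calibre it can help to confirm that a server is actually listening there. A minimal standalone sketch, using only the Python standard library and assuming a stock local Ollama install (whose root endpoint replies with the plain text "Ollama is running"):

#!/usr/bin/env python
# Standalone sketch: probe the default Ollama URL used by the OllamaAI plugin.
# Assumes a stock local install; not part of the calibre code above.
from urllib.error import URLError
from urllib.request import urlopen

DEFAULT_URL = 'http://localhost:11434'  # mirrors OllamaAI.DEFAULT_URL


def ollama_is_running(url: str = DEFAULT_URL) -> bool:
    try:
        with urlopen(url, timeout=5) as resp:  # root endpoint answers "Ollama is running"
            return resp.status == 200
    except URLError:
        return False


if __name__ == '__main__':
    print('Ollama reachable:', ollama_is_running())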
160
src/calibre/ai/ollama/backend.py
Normal file
@@ -0,0 +1,160 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2025, Kovid Goyal <kovid at kovidgoyal.net>

import datetime
import json
import posixpath
from collections.abc import Iterable, Iterator, Sequence
from functools import lru_cache
from typing import Any, NamedTuple
from urllib.parse import urlparse, urlunparse
from urllib.request import Request

from calibre.ai import ChatMessage, ChatMessageType, ChatResponse, PromptBlocked
from calibre.ai.ollama import OllamaAI
from calibre.ai.prefs import pref_for_provider
from calibre.ai.utils import chat_with_error_handler, develop_text_chat, download_data, read_streaming_response

module_version = 1  # needed for live updates


def pref(key: str, defval: Any = None) -> Any:
    return pref_for_provider(OllamaAI.name, key, defval)


def is_ready_for_use() -> bool:
    return bool(pref('text_model'))


def headers() -> tuple[tuple[str, str]]:
    return (
        ('Content-Type', 'application/json'),
    )


class Model(NamedTuple):
    # See https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
    name: str
    id: str
    family: str
    families: Sequence[str]
    modified_at: datetime.datetime

    @classmethod
    def from_dict(cls, x: dict[str, object]) -> 'Model':
        mid = x['model']
        d = x.get('details', {})
        return Model(
            name=x['name'], id=mid, family=d.get('family', ''), families=d.get('families', ()),
            modified_at=datetime.datetime.fromisoformat(x['modified_at'])
        )


def parse_models_list(entries: list[dict[str, Any]]) -> dict[str, Model]:
    ans = {}
    for entry in entries:
        e = Model.from_dict(entry)
        ans[e.id] = e
    return ans


def api_url(path: str = '') -> str:
    ans = pref('api_url') or OllamaAI.DEFAULT_URL
    purl = urlparse(ans)
    base_path = purl.path or '/'
    if path:
        path = posixpath.join(base_path, path)
        purl = purl._replace(path=path)
    return urlunparse(purl)


@lru_cache(2)
def get_available_models() -> dict[str, Model]:
    return parse_models_list(json.loads(download_data(api_url('api/tags')))['models'])


def does_model_exist_locally(model_id: str) -> bool:
    return model_id in get_available_models()


def config_widget():
    from calibre.ai.ollama.config import ConfigWidget
    return ConfigWidget()


def save_settings(config_widget):
    config_widget.save_settings()


def human_readable_model_name(model_id: str) -> str:
    if m := get_available_models().get(model_id):
        model_id = m.name
    return model_id


@lru_cache(2)
def model_choice_for_text() -> Model:
    return get_available_models()[pref('text_model')]


def chat_request(data: dict[str, Any], model: Model) -> Request:
    data['stream'] = True
    data['stream_options'] = {'include_usage': True}
    return Request(
        api_url('api/chat'), data=json.dumps(data).encode('utf-8'),
        headers=dict(headers()), method='POST')


def for_assistant(self: ChatMessage) -> dict[str, Any]:
    if self.type not in (ChatMessageType.assistant, ChatMessageType.system, ChatMessageType.user, ChatMessageType.developer):
        raise ValueError(f'Unsupported message type: {self.type}')
    return {'role': self.type.value, 'content': self.query}


def as_chat_responses(d: dict[str, Any], model: Model) -> Iterator[ChatResponse]:
    # See https://docs.github.com/en/rest/models/inference
    content = ''
    for choice in d['choices']:
        content += choice['delta'].get('content', '')
        if (fr := choice['finish_reason']) and fr != 'stop':
            yield ChatResponse(exception=PromptBlocked(custom_message=_('Result was blocked for reason: {}').format(fr)))
            return
    has_metadata = False
    if u := d.get('usage'):
        u  # TODO: implement costing
        has_metadata = True
    if has_metadata or content:
        yield ChatResponse(
            type=ChatMessageType.assistant, content=content, has_metadata=has_metadata, model=model.id, plugin_name=OllamaAI.name)


def text_chat_implementation(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
    # https://docs.github.com/en/rest/models/inference
    if use_model:
        model = get_available_models()[use_model]
    else:
        model = model_choice_for_text()
    data = {
        'model': model.id,
        'messages': [for_assistant(m) for m in messages],
    }
    rq = chat_request(data, model)
    for datum in read_streaming_response(rq, OllamaAI.name):
        for res in as_chat_responses(datum, model):
            yield res
            if res.exception:
                break


def text_chat(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
    yield from chat_with_error_handler(text_chat_implementation(messages, use_model))


def develop(use_model: str = '', msg: str = '') -> None:
    # calibre-debug -c 'from calibre.ai.ollama.backend import develop; develop()'
    m = (ChatMessage(msg),) if msg else ()
    develop_text_chat(text_chat, use_model, messages=m)


if __name__ == '__main__':
    develop()
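chat_request() above builds a streaming POST to api/chat and text_chat_implementation() consumes it through calibre's read_streaming_response() helper. For reference, Ollama's native /api/chat endpoint streams newline-delimited JSON objects, each carrying a fragment of the assistant reply plus a done flag. A standalone sketch of that wire format, independent of calibre's helpers; the model name 'llama3.2' is only an example and must already be pulled on the server:

#!/usr/bin/env python
# Standalone sketch of Ollama's native streaming chat protocol (NDJSON over HTTP).
# Not calibre code; the model name below is an example and must exist locally.
import json
from urllib.request import Request, urlopen


def stream_chat(prompt: str, model: str = 'llama3.2', base: str = 'http://localhost:11434'):
    payload = {
        'model': model,
        'messages': [{'role': 'user', 'content': prompt}],
        'stream': True,
    }
    rq = Request(
        base + '/api/chat', data=json.dumps(payload).encode('utf-8'),
        headers={'Content-Type': 'application/json'}, method='POST')
    with urlopen(rq) as resp:
        for line in resp:  # one JSON object per line
            d = json.loads(line)
            # each chunk carries part of the assistant reply; the final chunk has done == True
            yield d.get('message', {}).get('content', '')
            if d.get('done'):
                break


if __name__ == '__main__':
    print(''.join(stream_chat('Say hello in one short sentence.')))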
96
src/calibre/ai/ollama/config.py
Normal file
@@ -0,0 +1,96 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2025, Kovid Goyal <kovid at kovidgoyal.net>

from functools import partial

from qt.core import QFormLayout, QLabel, QLineEdit, QWidget

from calibre.ai.ollama import OllamaAI
from calibre.ai.prefs import pref_for_provider, set_prefs_for_provider
from calibre.ai.utils import configure, plugin_for_name
from calibre.gui2 import error_dialog

pref = partial(pref_for_provider, OllamaAI.name)


class ConfigWidget(QWidget):

    def __init__(self, parent: QWidget | None = None):
        super().__init__(parent)
        l = QFormLayout(self)
        l.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.ExpandingFieldsGrow)
        la = QLabel('<p>'+_(
            'Ollama allows you to run AI models locally on your own hardware. Once you have it running and properly'
            ' set up, fill in the fields below to have calibre use it as the AI provider.'
        ))
        la.setWordWrap(True)
        la.setOpenExternalLinks(True)
        l.addRow(la)

        self.api_url_edit = a = QLineEdit()
        a.setPlaceholderText(_('The Ollama URL, defaults to {}').format(OllamaAI.DEFAULT_URL))
        a.setToolTip(_('Enter the URL of the machine running your Ollama server, for example: {}').format(
            'https://my-ollama-server.com:11434'))
        self.text_model_edit = lm = QLineEdit(self)
        l.addRow(_('Ollama &URL:'), a)
        lm.setClearButtonEnabled(True)
        lm.setToolTip(_(
            'Enter the name of the model to use for text based tasks.'
        ))
        lm.setPlaceholderText(_('Enter name of model to use'))
        l.addRow(_('Model for &text tasks:'), lm)
        lm.setText(pref('text_model') or '')

    def does_model_exist_locally(self, model_name: str) -> bool:
        if not model_name:
            return False
        plugin = plugin_for_name(OllamaAI.name)
        return plugin.builtin_live_module.does_model_exist_locally(model_name)

    def available_models(self) -> list[str]:
        plugin = plugin_for_name(OllamaAI.name)
        return sorted(plugin.builtin_live_module.get_available_models(), key=lambda x: x.lower())

    @property
    def text_model(self) -> str:
        return self.text_model_edit.text().strip()

    @property
    def settings(self) -> dict[str, str]:
        ans = {
            'text_model': self.text_model,
        }
        url = self.api_url_edit.text().strip()
        if url:
            ans['api_url'] = url
        return ans

    @property
    def is_ready_for_use(self) -> bool:
        return bool(self.text_model)

    def validate(self) -> bool:
        if not self.text_model:
            error_dialog(self, _('No model specified'), _('You must specify a model to use for text based tasks.'), show=True)
            return False
        if not self.does_model_exist_locally(self.text_model):
            try:
                avail = self.available_models()
            except Exception:
                import traceback
                det_msg = _('Failed to get list of available models with error:') + '\n' + traceback.format_exc()
            else:
                det_msg = _('Available models:') + '\n' + '\n'.join(avail)

            error_dialog(self, _('No matching model'), _(
                'No model named {} found in Ollama. Click "Show details" to see a list of available models.').format(
                self.text_model), show=True, det_msg=det_msg)
            return False
        return True

    def save_settings(self):
        set_prefs_for_provider(OllamaAI.name, self.settings)


if __name__ == '__main__':
    configure(OllamaAI.name)
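validate() above refuses a model name that the Ollama server does not have, using the backend's does_model_exist_locally() and get_available_models(), and lists the available models in the error details. The same check can be reproduced outside Qt and calibre for quick testing; a standalone sketch against the /api/tags endpoint the backend queries, assuming a default local install:

#!/usr/bin/env python
# Standalone sketch of the model-existence check done by ConfigWidget.validate().
# Talks directly to /api/tags; not calibre code.
import json
from urllib.request import urlopen


def local_models(base: str = 'http://localhost:11434') -> list[str]:
    with urlopen(base + '/api/tags') as resp:
        data = json.load(resp)
    # 'model' is the identifier the backend's Model.from_dict() uses as the id
    return sorted(m['model'] for m in data.get('models', ()))


def model_exists(name: str) -> bool:
    return name in local_models()


if __name__ == '__main__':
    print('Available models:')
    print('\n'.join(local_models()))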
src/calibre/customize/builtins.py

@@ -6,6 +6,7 @@ import os

from calibre.ai.github import GitHubAI
from calibre.ai.google import GoogleAI
from calibre.ai.ollama import OllamaAI
from calibre.ai.open_router import OpenRouterAI
from calibre.constants import numeric_version
from calibre.customize import FileTypePlugin, InterfaceActionBase, MetadataReaderPlugin, MetadataWriterPlugin, PreferencesPlugin, StoreBase

@@ -1979,7 +1980,7 @@ plugins += [

# }}}

plugins.extend((OpenRouterAI, GoogleAI, GitHubAI,))
plugins.extend((OpenRouterAI, GoogleAI, GitHubAI, OllamaAI))

if __name__ == '__main__':
    # Test load speed