GitHub AI backend working apart from costing

commit 275c3c3918
parent e8712f26f1
@@ -136,8 +136,8 @@ class ResultBlockReason(Enum):
 class PromptBlocked(ValueError):
 
-    def __init__(self, reason: PromptBlockReason):
-        super().__init__(reason.for_human)
+    def __init__(self, reason: PromptBlockReason = PromptBlockReason.unknown, custom_message: str = ''):
+        super().__init__(custom_message or reason.for_human)
         self.reason = reason
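With the new signature, PromptBlocked can be raised either from a structured reason or with a free-form message. A minimal sketch of both call forms (the message text is illustrative):

    # reason defaults to PromptBlockReason.unknown, so a bare raise still carries a human-readable message
    raise PromptBlocked()
    # a custom message takes precedence, via `custom_message or reason.for_human`
    raise PromptBlocked(custom_message='Result was blocked for reason: length')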
@@ -15,6 +15,5 @@ class GitHubAI(AIProviderPlugin):
     def capabilities(self):
         from calibre.ai import AICapabilities
         return (
-            AICapabilities.text_to_text | AICapabilities.text_to_image | AICapabilities.text_and_image_to_image |
-            AICapabilities.embedding
+            AICapabilities.text_to_text | AICapabilities.embedding
         )
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # License: GPLv3 Copyright: 2025, Kovid Goyal <kovid at kovidgoyal.net>
 
+import datetime
 import json
 import os
 from collections.abc import Iterable, Iterator
@@ -8,7 +9,7 @@ from functools import lru_cache
 from typing import Any, NamedTuple
 from urllib.request import Request
 
-from calibre.ai import AICapabilities, ChatMessage, ChatMessageType, ChatResponse, Citation, NoAPIKey, PromptBlocked, ResultBlocked, WebLink
+from calibre.ai import AICapabilities, ChatMessage, ChatMessageType, ChatResponse, NoAPIKey, PromptBlocked
 from calibre.ai.github import GitHubAI
 from calibre.ai.prefs import decode_secret, pref_for_provider
 from calibre.ai.utils import chat_with_error_handler, develop_text_chat, get_cached_resource, read_streaming_response
@@ -16,6 +17,7 @@ from calibre.constants import cache_dir
 
 module_version = 1  # needed for live updates
 MODELS_URL = 'https://models.github.ai/catalog/models'
+CHAT_URL = 'https://models.github.ai/inference/chat/completions'
 API_VERSION = '2022-11-28'
@@ -53,62 +55,58 @@ class Model(NamedTuple):
     # See https://ai.google.dev/api/models#Model
     name: str
     id: str
-    slug: str
+    url: str
     description: str
     version: str
     context_length: int
     output_token_limit: int
     capabilities: AICapabilities
-    family: str
-    family_version: float
-    name_parts: tuple[str, ...]
     thinking: bool
+    publisher: str
 
     @classmethod
     def from_dict(cls, x: dict[str, object]) -> 'Model':
-        caps = AICapabilities.text_to_text
-        mid = x['name']
-        if 'embedContent' in x['supportedGenerationMethods']:
+        mid = x['id']
+        caps = AICapabilities.none
+        if 'embedding' in x['capabilities'] or 'embeddings' in x['supported_output_modalities']:
             caps |= AICapabilities.embedding
-        family, family_version = '', 0
-        name_parts = mid.rpartition('/')[-1].split('-')
-        if len(name_parts) > 1:
-            family, fv = name_parts[:2]
-            try:
-                family_version = float(fv)
-            except Exception:
-                family = ''
-        match family:
-            case 'imagen':
-                caps |= AICapabilities.text_to_image
-            case 'gemini':
-                if family_version >= 2.5:
-                    caps |= AICapabilities.text_and_image_to_image
-                if 'tts' in name_parts:
-                    caps |= AICapabilities.tts
+        else:
+            input_has_text = x['supported_input_modalities']
+            output_has_text = x['supported_output_modalities']
+            if input_has_text:
+                if output_has_text:
+                    caps |= AICapabilities.text_to_text
         return Model(
-            name=x['displayName'], id=mid, description=x.get('description', ''), version=x['version'],
-            context_length=int(x['inputTokenLimit']), output_token_limit=int(x['outputTokenLimit']),
-            capabilities=caps, family=family, family_version=family_version, name_parts=tuple(name_parts),
-            slug=mid, thinking=x.get('thinking', False)
+            name=x['name'], id=mid, description=x.get('summary', ''), version=x['version'],
+            context_length=int(x['limits']['max_input_tokens'] or 0), publisher=x['publisher'],
+            output_token_limit=int(x['limits']['max_output_tokens'] or 0),
+            capabilities=caps, url=x['html_url'], thinking='reasoning' in x['capabilities'],
        )
 
 
 def parse_models_list(entries: list[dict[str, Any]]) -> dict[str, Model]:
     ans = {}
-    for entry in entries['models']:
+    for entry in entries:
         e = Model.from_dict(entry)
         ans[e.id] = e
     return ans
 
 
 @lru_cache(2)
-def get_available_models() -> dict[str, 'Model']:
-    cache_loc = os.path.join(cache_dir(), 'github-ai', 'models-v1.json')
-    data = get_cached_resource(cache_loc, MODELS_URL, headers=headers())
+def get_available_models() -> dict[str, Model]:
+    cache_loc = os.path.join(cache_dir(), 'ai', f'{GitHubAI.name}-models-v1.json')
+    data = get_cached_resource(cache_loc, MODELS_URL)
     return parse_models_list(json.loads(data))
 
 
+def find_models_matching_name(name: str) -> Iterator[str]:
+    name = name.strip().lower()
+    for model in get_available_models().values():
+        q = model.name.strip().lower()
+        if name in q:
+            yield model.id
 
 
 def config_widget():
     from calibre.ai.github.config import ConfigWidget
     return ConfigWidget()
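For context, from_dict() above reads the following keys from each GitHub Models catalog entry. A sketch of one entry; only the key names come from the code, the values are illustrative:

    entry = {
        'id': 'openai/gpt-4.1',                      # illustrative id
        'name': 'OpenAI GPT-4.1',
        'publisher': 'OpenAI',
        'summary': 'General purpose chat model',     # becomes Model.description
        'version': '2025-04-14',
        'html_url': 'https://github.com/marketplace/models/...',
        'capabilities': ['streaming'],               # 'reasoning' here marks a thinking model
        'limits': {'max_input_tokens': 1048576, 'max_output_tokens': 32768},
        'supported_input_modalities': ['text'],
        'supported_output_modalities': ['text'],
    }
    # parse_models_list() maps a list of such entries to {entry['id']: Model.from_dict(entry)}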
@@ -124,128 +122,84 @@ def human_readable_model_name(model_id: str) -> str:
     return model_id
 
 
+@lru_cache(2)
+def newest_gpt_models() -> dict[str, Model]:
+    high, medium, low = [], [], []
+
+    def get_date(model: Model) -> datetime.date:
+        try:
+            return datetime.date.fromisoformat(model.version)
+        except Exception:
+            return datetime.date(2000, 1, 1)
+
+    for model in get_available_models().values():
+        if model.publisher == 'OpenAI' and '(preview)' not in model.name and (idp := model.id.split('/')[-1].split('-')) and 'gpt' in idp:
+            which = high
+            if 'mini' in model.id.split('-'):
+                which = medium
+            elif 'nano' in model.id.split('-'):
+                which = low
+            which.append(model)
+    return {
+        'high': sorted(high, key=get_date)[-1],
+        'medium': sorted(medium, key=get_date)[-1],
+        'low': sorted(low, key=get_date)[-1],
+    }
 
 
 @lru_cache(2)
 def model_choice_for_text() -> Model:
-    m = gemini_models()
-    return m.get(pref('model_strategy', 'medium')) or m['medium']
+    m = newest_gpt_models()
+    return m.get(pref('model_strategy', 'medium'), m['medium'])
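newest_gpt_models() buckets OpenAI models by id suffix, then picks the newest model in each bucket using the ISO date in model.version. Illustratively, assuming these ids exist in the catalog:

    # 'openai/gpt-4.1'       -> high
    # 'openai/gpt-4.1-mini'  -> medium
    # 'openai/gpt-4.1-nano'  -> low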
 
 
-def chat_request(data: dict[str, Any], model: Model, streaming: bool = True) -> Request:
-    url = f'{API_BASE_URL}/{model.slug}'
-    if streaming:
-        url += ':streamGenerateContent?alt=sse'
-    else:
-        url += ':generateContent'
-    return Request(url, data=json.dumps(data).encode('utf-8'), headers=dict(headers()), method='POST')
-
-
-def thinking_budget(m: Model) -> int | None:
-    # https://ai.google.dev/gemini-api/docs/thinking#set-budget
-    if not m.thinking:
-        return None
-    limits = 0, 24576
-    if 'pro' in m.name_parts:
-        limits = 128, 32768
-    elif 'lite' in m.name_parts:
-        limits = 512, 24576
-    match pref('reasoning_strategy', 'auto'):
-        case 'auto':
-            return -1
-        case 'none':
-            return limits[0] if 'pro' in m.name_parts else 0
-        case 'low':
-            return max(limits[0], int(0.2 * limits[1]))
-        case 'medium':
-            return max(limits[0], int(0.5 * limits[1]))
-        case 'high':
-            return max(limits[0], int(0.8 * limits[1]))
-    return None
+def chat_request(data: dict[str, Any], model: Model) -> Request:
+    data['stream'] = True
+    data['stream_options'] = {'include_usage': True}
+    return Request(
+        CHAT_URL, data=json.dumps(data).encode('utf-8'),
+        headers=dict(headers()), method='POST')
 
 
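Putting the pieces together: the body POSTed to CHAT_URL is assembled in text_chat_implementation() below and finished off by chat_request(). A sketch of the final payload, with illustrative model id and message content:

    data = {
        'model': 'openai/gpt-4.1',   # model.id, illustrative
        'messages': [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Suggest a title for this book.'},
        ],
        # added by chat_request():
        'stream': True,
        'stream_options': {'include_usage': True},
    }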
 def for_assistant(self: ChatMessage) -> dict[str, Any]:
-    return {'text': self.query}
+    if self.type not in (ChatMessageType.assistant, ChatMessageType.system, ChatMessageType.user, ChatMessageType.developer):
+        raise ValueError(f'Unsupported message type: {self.type}')
+    return {'role': self.type.value, 'content': self.query}
 
 
 def as_chat_responses(d: dict[str, Any], model: Model) -> Iterator[ChatResponse]:
-    # See https://ai.google.dev/api/generate-content#generatecontentresponse
-    if pf := d.get('promptFeedback'):
-        if br := pf.get('blockReason'):
-            yield ChatResponse(exception=PromptBlocked(block_reason(br)))
+    # See https://docs.github.com/en/rest/models/inference
+    content = ''
+    for choice in d['choices']:
+        content += choice['delta'].get('content', '')
+        if (fr := choice['finish_reason']) and fr != 'stop':
+            yield ChatResponse(exception=PromptBlocked(custom_message=_('Result was blocked for reason: {}').format(fr)))
             return
-    grounding_chunks, grounding_supports = [], []
-    for c in d['candidates']:
-        has_metadata = False
-        cost, currency = 0, ''
-        if fr := c.get('finishReason'):
-            if fr == 'STOP':
-                has_metadata = True
-                cost, currency = model.get_cost(d['usageMetadata'])
-            else:
-                yield ChatResponse(exception=ResultBlocked(result_block_reason(fr)))
-                return
-        content = c['content']
-        if gm := c.get('groundingMetadata'):
-            grounding_chunks.extend(gm['groundingChunks'])
-            grounding_supports.extend(gm['groundingSupports'])
-        citations, web_links = [], []
-        if has_metadata:
-            for x in grounding_chunks:
-                if w := x.get('web'):
-                    web_links.append(WebLink(**w))
-                else:
-                    web_links.append(WebLink())
-
-            for s in grounding_supports:
-                if links := tuple(i for i in s['groundingChunkIndices'] if web_links[i]):
-                    seg = s['segment']
-                    citations.append(Citation(
-                        links, start_offset=seg.get('startIndex', 0), end_offset=seg.get('endIndex', 0), text=seg.get('text', '')))
-        role = ChatMessageType.user if 'user' == content.get('role') else ChatMessageType.assistant
-        content_parts = []
-        reasoning_parts = []
-        reasoning_details = []
-        for part in content['parts']:
-            if text := part.get('text'):
-                (reasoning_parts if part.get('thought') else content_parts).append(text)
-            if ts := part.get('thoughtSignature'):
-                reasoning_details.append({'signature': ts})
+    has_metadata = False
+    if u := d.get('usage'):
+        u  # TODO: implement costing
+        has_metadata = True
+    if has_metadata or content:
         yield ChatResponse(
-            type=role, content=''.join(content_parts), reasoning=''.join(reasoning_parts),
-            reasoning_details=tuple(reasoning_details), has_metadata=has_metadata, model=model.id,
-            cost=cost, plugin_name=GitHubAI.name, currency=currency, citations=citations, web_links=web_links,
-        )
+            type=ChatMessageType.assistant, content=content, has_metadata=has_metadata, model=model.id, plugin_name=GitHubAI.name)
 
 
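as_chat_responses() above consumes one parsed SSE chunk at a time. A sketch of the two chunk shapes it handles, in the OpenAI-compatible streaming format the code reads; the field values are illustrative:

    # ordinary delta chunk
    {'choices': [{'delta': {'content': 'Once upon a time'}, 'finish_reason': None}]}
    # final chunk: finish_reason is set, and 'usage' is present because include_usage was requested
    {'choices': [{'delta': {}, 'finish_reason': 'stop'}], 'usage': {'total_tokens': 42}}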
 def text_chat_implementation(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
-    # See https://ai.google.dev/gemini-api/docs/text-generation
+    # https://docs.github.com/en/rest/models/inference
     if use_model:
         model = get_available_models()[use_model]
     else:
         model = model_choice_for_text()
-    contents = []
-    system_instructions = []
-    for m in messages:
-        d = system_instructions if m.type is ChatMessageType.system else contents
-        d.append(for_assistant(m))
     data = {
-        # See https://ai.google.dev/api/generate-content#v1beta.GenerationConfig
-        'generationConfig': {
-            'thinkingConfig': {
-                'includeThoughts': True,
-            },
-        },
+        'model': model.id,
+        'messages': [for_assistant(m) for m in messages],
     }
-    if (tb := thinking_budget(model)) is not None:
-        data['generationConfig']['thinkingConfig']['thinkingBudget'] = tb
-    if system_instructions:
-        data['system_instruction'] = {'parts': system_instructions}
-    if contents:
-        data['contents'] = [{'parts': contents}]
-    if pref('allow_web_searches', True):
-        data['tools'] = [{'google_search': {}}]
     rq = chat_request(data, model)
 
     for datum in read_streaming_response(rq, GitHubAI.name):
-        yield from as_chat_responses(datum, model)
+        for res in as_chat_responses(datum, model):
+            yield res
+            if res.exception:
+                break
 
 
 def text_chat(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
@@ -254,9 +208,8 @@ def text_chat(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
 
 def develop(use_model: str = '', msg: str = '') -> None:
     # calibre-debug -c 'from calibre.ai.github.backend import develop; develop()'
-    print('\n'.join(f'{k}:{m.id}' for k, m in gemini_models().items()))
     m = (ChatMessage(msg),) if msg else ()
-    develop_text_chat(text_chat, ('models/' + use_model) if use_model else '', messages=m)
+    develop_text_chat(text_chat, use_model, messages=m)
 
 
 if __name__ == '__main__':
@@ -1,14 +1,14 @@
 #!/usr/bin/env python
 # License: GPLv3 Copyright: 2025, Kovid Goyal <kovid at kovidgoyal.net>
 
 
 from collections.abc import Sequence
 from functools import partial
 
-from qt.core import QComboBox, QFormLayout, QLabel, QLineEdit, QWidget
+from qt.core import QComboBox, QFormLayout, QHBoxLayout, QLabel, QLineEdit, QWidget
 
 from calibre.ai.github import GitHubAI
 from calibre.ai.prefs import decode_secret, encode_secret, pref_for_provider, set_prefs_for_provider
-from calibre.ai.utils import configure
+from calibre.ai.utils import configure, plugin_for_name
 from calibre.gui2 import error_dialog
 
 pref = partial(pref_for_provider, GitHubAI.name)
@@ -50,6 +50,23 @@ class ConfigWidget(QWidget):
             'The model choice strategy controls how a model to query is chosen. Cheaper and faster models give lower'
             ' quality results.'
         ))
+        self.text_model_edit = lm = QLineEdit(self)
+        lm.setClearButtonEnabled(True)
+        lm.setToolTip(_(
+            'Enter a name of the model to use for text based tasks.'
+            ' If not specified, one is chosen automatically.'
+        ))
+        lm.setPlaceholderText(_('Optionally, enter name of model to use'))
+        self.browse_label = la = QLabel(f'<a href="https://github.com/marketplace?type=models">{_("Browse")}</a>')
+        tm = QWidget()
+        la.setOpenExternalLinks(True)
+        h = QHBoxLayout(tm)
+        h.setContentsMargins(0, 0, 0, 0)
+        h.addWidget(lm), h.addWidget(la)
+        l.addRow(_('Model for &text tasks:'), tm)
+        self.initial_text_model = pm = pref('text_model') or {'name': '', 'id': ''}
+        if pm:
+            lm.setText(pm['name'])
 
     @property
     def api_key(self) -> str:
@@ -61,19 +78,37 @@ class ConfigWidget(QWidget):
 
     @property
     def settings(self) -> dict[str, str]:
-        return {
+        name = self.text_model_edit.text().strip()
+        ans = {
             'api_key': encode_secret(self.api_key), 'model_choice_strategy': self.model_choice_strategy,
         }
+        if name:
+            ans['text_model'] = {'name': name, 'id': self.model_ids_for_name(name)[0]}
+        return ans
 
     @property
     def is_ready_for_use(self) -> bool:
         return bool(self.api_key)
 
+    def model_ids_for_name(self, name: str) -> Sequence[str]:
+        if name and name == self.initial_text_model['name']:
+            return (self.initial_text_model['id'],)
+        plugin = plugin_for_name(GitHubAI.name)
+        return tuple(plugin.builtin_live_module.find_models_matching_name(name))
+
     def validate(self) -> bool:
-        if self.is_ready_for_use:
-            return True
-        error_dialog(self, _('No API key'), _('You must supply a Personal access token to use GitHub AI.'), show=True)
-        return False
+        if not self.is_ready_for_use:
+            error_dialog(self, _('No API key'), _('You must supply a Personal access token to use GitHub AI.'), show=True)
+            return False
+        if (name := self.text_model_edit.text().strip()) and name:
+            num = len(self.model_ids_for_name(name))
+            if num == 0:
+                error_dialog(self, _('No matching model'), _('No model named {} found on GitHub').format(name), show=True)
+                return False
+            if num > 1:
+                error_dialog(self, _('Ambiguous model name'), _('The name {} matches more than one model on GitHub').format(name), show=True)
+                return False
+        return True
 
     def save_settings(self):
         set_prefs_for_provider(GitHubAI.name, self.settings)
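The saved provider preferences produced by the settings property above end up shaped like this; the values are illustrative, and 'text_model' is only present when a model name was typed in:

    {
        'api_key': '<encoded secret>',
        'model_choice_strategy': 'medium',
        'text_model': {'name': 'GPT-4.1', 'id': 'openai/gpt-4.1'},
    }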
@@ -185,7 +185,7 @@ def parse_models_list(entries: list[dict[str, Any]]) -> dict[str, Model]:
 @lru_cache(2)
 def get_available_models() -> dict[str, 'Model']:
     api_key = decoded_api_key()
-    cache_loc = os.path.join(cache_dir(), 'google-ai', 'models-v1.json')
+    cache_loc = os.path.join(cache_dir(), 'ai', f'{GoogleAI.name}-models-v1.json')
     data = get_cached_resource(cache_loc, MODELS_URL, headers=(('X-goog-api-key', api_key),))
     return parse_models_list(json.loads(data))
@@ -377,7 +377,10 @@ def text_chat_implementation(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
     rq = chat_request(data, model)
 
     for datum in read_streaming_response(rq, GoogleAI.name):
-        yield from as_chat_responses(datum, model)
+        for res in as_chat_responses(datum, model):
+            yield res
+            if res.exception:
+                break
 
 
 def text_chat(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
@@ -26,7 +26,7 @@ def pref(key: str, defval: Any = None) -> Any:
 
 @lru_cache(2)
 def get_available_models() -> dict[str, 'Model']:
-    cache_loc = os.path.join(cache_dir(), 'openrouter', 'models-v1.json')
+    cache_loc = os.path.join(cache_dir(), 'ai', f'{OpenRouterAI.name}-models-v1.json')
     data = get_cached_resource(cache_loc, MODELS_URL)
     return parse_models_list(json.loads(data))
@@ -19,6 +19,7 @@ from urllib.request import ProxyHandler, Request, build_opener
 from calibre import get_proxies
 from calibre.ai import ChatMessage, ChatMessageType, ChatResponse, Citation, WebLink
 from calibre.constants import __version__
+from calibre.customize import AIProviderPlugin
 from calibre.customize.ui import available_ai_provider_plugins
@@ -286,17 +287,20 @@ def develop_text_chat(
     pprint(msg)
 
 
+def plugin_for_name(plugin_name: str) -> AIProviderPlugin:
+    for plugin in available_ai_provider_plugins():
+        if plugin.name == plugin_name:
+            return plugin
+    raise KeyError(f'No plugin named {plugin_name} is available')
+
+
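plugin_for_name() gives other modules a one-line way to resolve a provider plugin, as configure() below and the GitHub config widget both do. A minimal usage sketch:

    from calibre.ai.github import GitHubAI
    from calibre.ai.utils import plugin_for_name

    plugin = plugin_for_name(GitHubAI.name)  # raises KeyError if the plugin is unavailable
    cw = plugin.config_widget()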
 def configure(plugin_name: str, parent: Any = None) -> None:
     from qt.core import QDialog, QDialogButtonBox, QVBoxLayout
 
     from calibre.gui2 import ensure_app
     ensure_app(headless=False)
-    for plugin in available_ai_provider_plugins():
-        if plugin.name == plugin_name:
-            cw = plugin.config_widget()
-            break
-    else:
-        raise KeyError(f'No plugin named: {plugin_name}')
+    plugin = plugin_for_name(plugin_name)
+    cw = plugin.config_widget()
 
     class D(QDialog):
         def accept(self):
             if not cw.validate():