mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-12-30 16:50:20 -05:00
More work on integration new AI backend into existing viewer code
This commit is contained in:
parent
dc7977db89
commit
c8acb5d2d3
@ -13,7 +13,7 @@ from functools import lru_cache
|
||||
from pprint import pprint
|
||||
from threading import Thread
|
||||
from typing import Any, NamedTuple
|
||||
from urllib.error import HTTPError
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.request import ProxyHandler, Request, build_opener
|
||||
|
||||
from calibre import browser, get_proxies
|
||||
@ -346,7 +346,14 @@ def text_chat(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[
|
||||
details = e.fp.read().decode()
|
||||
except Exception:
|
||||
details = ''
|
||||
try:
|
||||
error_json = json.loads(details)
|
||||
details = error_json.get('error', {}).get('message', details)
|
||||
except Exception:
|
||||
pass
|
||||
yield ChatResponse(exception=e, error_details=details)
|
||||
except URLError as e:
|
||||
yield ChatResponse(exception=e, error_details=f'Network error: {e.reason}')
|
||||
except Exception as e:
|
||||
import traceback
|
||||
yield ChatResponse(exception=e, error_details=traceback.format_exc())
|
||||
|
||||
@ -1,15 +1,11 @@
|
||||
# License: GPL v3 Copyright: 2025, Amir Tehrani and Kovid Goyal
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
from collections.abc import Iterator
|
||||
from functools import lru_cache, partial
|
||||
from itertools import count
|
||||
from threading import Thread
|
||||
from typing import NamedTuple
|
||||
from urllib import request
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from qt.core import (
|
||||
QAbstractItemView,
|
||||
@ -33,14 +29,17 @@ from qt.core import (
|
||||
Qt,
|
||||
QTabWidget,
|
||||
QTextBrowser,
|
||||
QUrl,
|
||||
QVBoxLayout,
|
||||
QWidget,
|
||||
pyqtSignal,
|
||||
)
|
||||
|
||||
from calibre.ai import AICapabilities, ChatMessage, ChatMessageType
|
||||
from calibre.ai import AICapabilities, ChatMessage, ChatMessageType, ChatResponse
|
||||
from calibre.ai.config import ConfigureAI
|
||||
from calibre.ai.prefs import plugin_for_purpose
|
||||
from calibre.ai.utils import StreamedResponseAccumulator
|
||||
from calibre.customize import AIProviderPlugin
|
||||
from calibre.ebooks.metadata import authors_to_string
|
||||
from calibre.gui2 import Application, error_dialog
|
||||
from calibre.gui2.dialogs.confirm_delete import confirm
|
||||
@ -48,69 +47,7 @@ from calibre.gui2.viewer.config import vprefs
|
||||
from calibre.gui2.viewer.highlights import HighlightColorCombo
|
||||
from calibre.gui2.widgets2 import Dialog
|
||||
from calibre.utils.icu import primary_sort_key
|
||||
|
||||
# --- Backend Abstraction & Cost Data ---
|
||||
API_PROVIDERS = {
|
||||
'openrouter': {
|
||||
'url': 'https://openrouter.ai/api/v1/chat/completions',
|
||||
'headers': lambda api_key: {
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
'HTTP-Referer': 'https://calibre-ebook.com',
|
||||
'X-Title': 'calibre'
|
||||
},
|
||||
'payload': lambda model_id, messages: {
|
||||
'model': model_id,
|
||||
'messages': messages
|
||||
},
|
||||
'parse_response': lambda r_json: (
|
||||
r_json['choices'][0]['message']['content'],
|
||||
r_json.get('usage', {'prompt_tokens': 0, 'completion_tokens': 0})
|
||||
)
|
||||
}
|
||||
}
|
||||
# --- End Backend Abstraction ---
|
||||
|
||||
|
||||
class LLMAPICall(Thread):
|
||||
def __init__(self, conversation_history, signal_emitter):
|
||||
super().__init__(daemon=True)
|
||||
self.conversation_history = conversation_history
|
||||
self.signal_emitter = signal_emitter
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
url = self.provider_config['url']
|
||||
headers = self.provider_config['headers'](self.api_key)
|
||||
payload = self.provider_config['payload'](self.model_id, self.conversation_history)
|
||||
|
||||
encoded_data = json.dumps(payload).encode('utf-8')
|
||||
req = request.Request(url, data=encoded_data, headers=headers, method='POST')
|
||||
|
||||
with request.urlopen(req, timeout=90) as response:
|
||||
response_data = response.read().decode('utf-8')
|
||||
response_json = json.loads(response_data)
|
||||
|
||||
if 'error' in response_json:
|
||||
raise Exception(response_json['error'].get('message', 'Unknown API error'))
|
||||
if not response_json.get('choices'):
|
||||
raise Exception('API response did not contain any choices.')
|
||||
|
||||
result_text, usage_data = self.provider_config['parse_response'](response_json)
|
||||
self.signal_emitter.emit(result_text, usage_data)
|
||||
|
||||
except HTTPError as e:
|
||||
error_body = e.read().decode('utf-8')
|
||||
try:
|
||||
error_json = json.loads(error_body)
|
||||
msg = error_json.get('error', {}).get('message', error_body)
|
||||
except json.JSONDecodeError:
|
||||
msg = error_body
|
||||
self.signal_emitter.emit(f"<p style='color:red;'><b>API Error ({e.code}):</b> {msg}</p>", {})
|
||||
except URLError as e:
|
||||
self.signal_emitter.emit(f"<p style='color:red;'><b>Network Error:</b> {e.reason}</p>", {})
|
||||
except Exception as e:
|
||||
self.signal_emitter.emit(f"<p style='color:red;'><b>An unexpected error occurred:</b> {e}</p>", {})
|
||||
from calibre.utils.short_uuid import uuid4
|
||||
|
||||
|
||||
class Action(NamedTuple):
|
||||
@ -128,12 +65,12 @@ class Action(NamedTuple):
|
||||
@lru_cache(2)
|
||||
def default_actions() -> tuple[Action, ...]:
|
||||
return (
|
||||
Action('summarize', _('Summarize'), 'Provide a concise summary of the following text.'),
|
||||
Action('explain', _('Explain'), 'Explain the following text in simple, easy-to-understand terms.'),
|
||||
Action('points', _('Key points'), 'Extract the key points from the following text as a bulleted list.'),
|
||||
Action('define', _('Define'), 'Identify and define any technical or complex terms in the following text.'),
|
||||
Action('grammar', _('Correct grammar'), 'Correct any grammatical errors in the following text and provide the corrected version.'),
|
||||
Action('english', _('As English'), 'Translate the following text into English.'),
|
||||
Action('summarize', _('Summarize'), 'Provide a concise summary of the selected text.'),
|
||||
Action('explain', _('Explain'), 'Explain the selected text in simple, easy-to-understand terms.'),
|
||||
Action('points', _('Key points'), 'Extract the key points from the selected text as a bulleted list.'),
|
||||
Action('define', _('Define'), 'Identify and define any technical or complex terms in the selected text.'),
|
||||
Action('grammar', _('Correct grammar'), 'Correct any grammatical errors in the selected text and provide the corrected version.'),
|
||||
Action('english', _('As English'), 'Translate the selected text into English.'),
|
||||
)
|
||||
|
||||
|
||||
@ -153,8 +90,10 @@ def current_actions(include_disabled=False):
|
||||
class ConversationHistory:
|
||||
|
||||
def __init__(self, conversation_text: str = ''):
|
||||
self.accumulator = StreamedResponseAccumulator()
|
||||
self.items: list[ChatMessage] = []
|
||||
self.conversation_text: str = conversation_text
|
||||
self.model_used = ''
|
||||
|
||||
def __iter__(self) -> Iterator[ChatMessage]:
|
||||
return iter(self.items)
|
||||
@ -170,6 +109,7 @@ class ConversationHistory:
|
||||
|
||||
def copy(self, upto: int | None = None) -> 'ConversationHistory':
|
||||
ans = ConversationHistory(self.conversation_text)
|
||||
ans.model_used = self.model_used
|
||||
if upto is None:
|
||||
ans.items = list(self.items)
|
||||
else:
|
||||
@ -223,12 +163,13 @@ def format_llm_note(conversation: ConversationHistory) -> str:
|
||||
|
||||
|
||||
class LLMPanel(QWidget):
|
||||
response_received = pyqtSignal(str, dict)
|
||||
response_received = pyqtSignal(int, object)
|
||||
add_note_requested = pyqtSignal(dict)
|
||||
_SAVE_ACTION_URL_SCHEME = 'calibre-llm-action'
|
||||
|
||||
def __init__(self, parent=None, viewer=None, lookup_widget=None):
|
||||
super().__init__(parent)
|
||||
self.save_note_hostname = f'{uuid4().lower()}.calibre'
|
||||
self.configure_ai_hostname = f'{uuid4().lower()}.calibre'
|
||||
self.viewer = viewer
|
||||
self.counter = count(start=1)
|
||||
self.lookup_widget = lookup_widget
|
||||
@ -262,7 +203,7 @@ class LLMPanel(QWidget):
|
||||
self.layout.addWidget(custom_prompt_group)
|
||||
|
||||
self.result_display = QTextBrowser(self)
|
||||
self.result_display.setOpenExternalLinks(False)
|
||||
self.result_display.setOpenLinks(False)
|
||||
self.result_display.setMinimumHeight(150)
|
||||
self.result_display.anchorClicked.connect(self._on_chat_link_clicked)
|
||||
self.layout.addWidget(self.result_display)
|
||||
@ -293,7 +234,7 @@ class LLMPanel(QWidget):
|
||||
|
||||
self.custom_prompt_button.clicked.connect(self.run_custom_prompt)
|
||||
self.custom_prompt_edit.returnPressed.connect(self.run_custom_prompt)
|
||||
self.response_received.connect(self.show_response)
|
||||
self.response_received.connect(self.on_response_from_ai, type=Qt.ConnectionType.QueuedConnection)
|
||||
self.settings_button.clicked.connect(self.show_settings)
|
||||
self.show_initial_message()
|
||||
|
||||
@ -326,18 +267,21 @@ class LLMPanel(QWidget):
|
||||
dialog.actions_updated.connect(self.rebuild_actions_ui)
|
||||
dialog.exec()
|
||||
|
||||
@property
|
||||
def ai_provider_plugin(self) -> AIProviderPlugin | None:
|
||||
return plugin_for_purpose(AICapabilities.text_to_text)
|
||||
|
||||
@property
|
||||
def is_ready_for_use(self) -> bool:
|
||||
p = plugin_for_purpose(AICapabilities.text_to_text)
|
||||
p = self.ai_provider_plugin
|
||||
return p is not None and p.is_ready_for_use
|
||||
|
||||
def show_initial_message(self):
|
||||
self.save_note_button.setEnabled(False)
|
||||
if not self.is_ready_for_use:
|
||||
self.show_response('<p>' + _(
|
||||
'Please configure an AI provider by clicking the <b>Settings</b> button below.'), is_error_or_status=True)
|
||||
self.show_html(f'<p><a href="http://{self.configure_ai_hostname}">{_("First, configure an AI provider")}')
|
||||
else:
|
||||
self.show_response(_('Select text in the book to begin.'), is_error_or_status=True)
|
||||
self.show_html('<p>' + _('Select text in the book to begin.'))
|
||||
|
||||
def update_with_text(self, text, highlight_data, is_read_only_view=False):
|
||||
new_uuid = highlight_data.get('uuid') if highlight_data else None
|
||||
@ -365,9 +309,9 @@ class LLMPanel(QWidget):
|
||||
self.conversation_history = ConversationHistory()
|
||||
|
||||
if text:
|
||||
self.show_response(f"<b>{_('Selected')}:</b><br><i>'{text[:200]}...'</i>", is_error_or_status=True)
|
||||
self.show_html(f"<b>{_('Selected')}:</b><br><i>'{text[:200]}...'</i>")
|
||||
else:
|
||||
self.show_response(_('<b>Ready.</b> Ask a follow-up question.'), is_error_or_status=True)
|
||||
self.show_html(_('<b>Ready.</b> Ask a follow-up question.'))
|
||||
|
||||
if self.latched_highlight_uuid:
|
||||
self.save_note_button.setToolTip(_("Append this response to the existing highlight's note"))
|
||||
@ -413,7 +357,7 @@ class LLMPanel(QWidget):
|
||||
else:
|
||||
bgcolor = assistant_bgcolor
|
||||
label = _('Assistant')
|
||||
save_button_href = f'http://{self._SAVE_ACTION_URL_SCHEME}/save?index={i}'
|
||||
save_button_href = f'http://{self.save_note_hostname}/{i}'
|
||||
html_output += f'''
|
||||
<table style="{base_table_style}" bgcolor="{bgcolor}" cellspacing="0" cellpadding="0">
|
||||
<tr>
|
||||
@ -429,11 +373,11 @@ class LLMPanel(QWidget):
|
||||
|
||||
def start_api_call(self, action_prompt):
|
||||
if not self.is_ready_for_use:
|
||||
self.show_response(f"<p style='color:orange'><b>{_('AI provider not configured')}</b> {_(
|
||||
'Click the <b>Settings</b> button to configure an AI service provider.')}</p>", is_error_or_status=True)
|
||||
self.show_error(f"<b>{_('AI provider not configured.')}</b> <a href='http://configure-ai.com'>{_(
|
||||
'Configure AI provider')}</a>", is_critical=False)
|
||||
return
|
||||
if not self.latched_conversation_text:
|
||||
self.show_response(f"<p style='color:red'><b>{_('Error')}:</b> {_('No text is selected for this conversation.')}</p>", is_error_or_status=True)
|
||||
self.show_error(f"<b>{_('Error')}:</b> {_('No text is selected for this conversation.')}", is_critical=True)
|
||||
return
|
||||
|
||||
if not self.conversation_history:
|
||||
@ -446,16 +390,31 @@ class LLMPanel(QWidget):
|
||||
context_header += '.\n\n'
|
||||
context_header += f'I have selected the following text from this book:\n{self.latched_conversation_text}\n\n'
|
||||
self.conversation_history.append(ChatMessage(
|
||||
id=next(self.counter), query=context_header, type=ChatMessage.system, extra_data=self.latched_conversation_text))
|
||||
self.conversation_history.append(ChatMessage(
|
||||
id=next(self.counter), query=action_prompt, type=ChatMessageType.user))
|
||||
api_call_history = self.conversation_history.copy()
|
||||
query=context_header, type=ChatMessage.system, extra_data=self.latched_conversation_text))
|
||||
self.conversation_history.append(ChatMessage(query=action_prompt, type=ChatMessageType.user))
|
||||
self.result_display.setHtml(self._render_conversation_html(thinking=True))
|
||||
self.result_display.verticalScrollBar().setValue(self.result_display.verticalScrollBar().maximum())
|
||||
self.set_all_inputs_enabled(False)
|
||||
|
||||
api_call_thread = LLMAPICall(api_call_history, self.response_received)
|
||||
api_call_thread.start()
|
||||
self.current_api_call_number = next(self.counter)
|
||||
api_call = Thread(name='LLMAPICall', daemon=True, target=self.do_api_call, args=(
|
||||
self.conversation_history.copy(), self.current_api_call_number, self.ai_provider_plugin))
|
||||
api_call.start()
|
||||
|
||||
def do_api_call(
|
||||
self, conversation_history: ConversationHistory, current_api_call_number: int, ai_plugin: AIProviderPlugin
|
||||
) -> None:
|
||||
for res in ai_plugin.text_chat(conversation_history.items, conversation_history.model_used):
|
||||
self.response_received.emit(current_api_call_number, res)
|
||||
self.response_received.emit(current_api_call_number, None)
|
||||
|
||||
def on_response_from_ai(self, current_api_call_number: int, r: ChatResponse | None) -> None:
|
||||
if current_api_call_number != self.current_api_call_number:
|
||||
return
|
||||
if r is None:
|
||||
self.conversation_history.finalize_response()
|
||||
else:
|
||||
self.conversation_history.accumulator.accumulate(r)
|
||||
|
||||
def show_response(self, response_text, usage_data=None, is_error_or_status=False):
|
||||
self.last_response_text = ''
|
||||
@ -478,6 +437,16 @@ class LLMPanel(QWidget):
|
||||
self.set_all_inputs_enabled(True)
|
||||
self.custom_prompt_edit.clear()
|
||||
|
||||
def show_html(self, html: str) -> None:
|
||||
self.save_note_button.setEnabled(bool(self.last_response_text) and bool(self.latched_conversation_text))
|
||||
self.result_display.setHtml(html)
|
||||
self.result_display.verticalScrollBar().setValue(self.result_display.verticalScrollBar().maximum())
|
||||
self.set_all_inputs_enabled(True)
|
||||
self.custom_prompt_edit.clear()
|
||||
|
||||
def show_error(self, html: str, is_critical: bool = False) -> None:
|
||||
self.show_html(f'<p style="color: {"red" if is_critical else "orange"}">{html}')
|
||||
|
||||
def update_cost(self, usage_data):
|
||||
model_id = vprefs.get('llm_model_id', 'google/gemini-1.5-flash')
|
||||
prompt_tokens = usage_data.get('prompt_tokens', 0)
|
||||
@ -511,19 +480,13 @@ class LLMPanel(QWidget):
|
||||
}
|
||||
self.add_note_requested.emit(payload)
|
||||
|
||||
def _on_chat_link_clicked(self, qurl):
|
||||
url_str = qurl.toString()
|
||||
parsed_url = urlparse(url_str)
|
||||
if parsed_url.hostname == self._SAVE_ACTION_URL_SCHEME and parsed_url.path == '/save':
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
index_str = query_params.get('index', [None])[0]
|
||||
if index_str is not None:
|
||||
try:
|
||||
index = int(index_str)
|
||||
self.save_specific_note(index)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return
|
||||
def _on_chat_link_clicked(self, qurl: QUrl):
|
||||
match qurl.host():
|
||||
case self.save_note_hostname:
|
||||
index = int(qurl.path().strip('/'))
|
||||
self.save_specific_note(index)
|
||||
case self.configure_ai_hostname:
|
||||
self.show_settings()
|
||||
|
||||
def set_all_inputs_enabled(self, enabled):
|
||||
for i in range(self.quick_actions_layout.count()):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user