From c8acb5d2d34f044e437de52a3299e457ec24c887 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Thu, 4 Sep 2025 10:16:49 +0530 Subject: [PATCH] More work on integration new AI backend into existing viewer code --- src/calibre/ai/open_router/backend.py | 9 +- src/calibre/gui2/viewer/llm.py | 177 ++++++++++---------------- 2 files changed, 78 insertions(+), 108 deletions(-) diff --git a/src/calibre/ai/open_router/backend.py b/src/calibre/ai/open_router/backend.py index 0ece35046d..e3929bd78e 100644 --- a/src/calibre/ai/open_router/backend.py +++ b/src/calibre/ai/open_router/backend.py @@ -13,7 +13,7 @@ from functools import lru_cache from pprint import pprint from threading import Thread from typing import Any, NamedTuple -from urllib.error import HTTPError +from urllib.error import HTTPError, URLError from urllib.request import ProxyHandler, Request, build_opener from calibre import browser, get_proxies @@ -346,7 +346,14 @@ def text_chat(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ details = e.fp.read().decode() except Exception: details = '' + try: + error_json = json.loads(details) + details = error_json.get('error', {}).get('message', details) + except Exception: + pass yield ChatResponse(exception=e, error_details=details) + except URLError as e: + yield ChatResponse(exception=e, error_details=f'Network error: {e.reason}') except Exception as e: import traceback yield ChatResponse(exception=e, error_details=traceback.format_exc()) diff --git a/src/calibre/gui2/viewer/llm.py b/src/calibre/gui2/viewer/llm.py index 9a30f74aa0..f84daa7368 100644 --- a/src/calibre/gui2/viewer/llm.py +++ b/src/calibre/gui2/viewer/llm.py @@ -1,15 +1,11 @@ # License: GPL v3 Copyright: 2025, Amir Tehrani and Kovid Goyal -import json import textwrap from collections.abc import Iterator from functools import lru_cache, partial from itertools import count from threading import Thread from typing import NamedTuple -from urllib import request -from urllib.error import HTTPError, URLError -from urllib.parse import parse_qs, urlparse from qt.core import ( QAbstractItemView, @@ -33,14 +29,17 @@ from qt.core import ( Qt, QTabWidget, QTextBrowser, + QUrl, QVBoxLayout, QWidget, pyqtSignal, ) -from calibre.ai import AICapabilities, ChatMessage, ChatMessageType +from calibre.ai import AICapabilities, ChatMessage, ChatMessageType, ChatResponse from calibre.ai.config import ConfigureAI from calibre.ai.prefs import plugin_for_purpose +from calibre.ai.utils import StreamedResponseAccumulator +from calibre.customize import AIProviderPlugin from calibre.ebooks.metadata import authors_to_string from calibre.gui2 import Application, error_dialog from calibre.gui2.dialogs.confirm_delete import confirm @@ -48,69 +47,7 @@ from calibre.gui2.viewer.config import vprefs from calibre.gui2.viewer.highlights import HighlightColorCombo from calibre.gui2.widgets2 import Dialog from calibre.utils.icu import primary_sort_key - -# --- Backend Abstraction & Cost Data --- -API_PROVIDERS = { - 'openrouter': { - 'url': 'https://openrouter.ai/api/v1/chat/completions', - 'headers': lambda api_key: { - 'Authorization': f'Bearer {api_key}', - 'Content-Type': 'application/json', - 'HTTP-Referer': 'https://calibre-ebook.com', - 'X-Title': 'calibre' - }, - 'payload': lambda model_id, messages: { - 'model': model_id, - 'messages': messages - }, - 'parse_response': lambda r_json: ( - r_json['choices'][0]['message']['content'], - r_json.get('usage', {'prompt_tokens': 0, 'completion_tokens': 0}) - ) - } -} -# --- End Backend Abstraction --- - - -class LLMAPICall(Thread): - def __init__(self, conversation_history, signal_emitter): - super().__init__(daemon=True) - self.conversation_history = conversation_history - self.signal_emitter = signal_emitter - - def run(self): - try: - url = self.provider_config['url'] - headers = self.provider_config['headers'](self.api_key) - payload = self.provider_config['payload'](self.model_id, self.conversation_history) - - encoded_data = json.dumps(payload).encode('utf-8') - req = request.Request(url, data=encoded_data, headers=headers, method='POST') - - with request.urlopen(req, timeout=90) as response: - response_data = response.read().decode('utf-8') - response_json = json.loads(response_data) - - if 'error' in response_json: - raise Exception(response_json['error'].get('message', 'Unknown API error')) - if not response_json.get('choices'): - raise Exception('API response did not contain any choices.') - - result_text, usage_data = self.provider_config['parse_response'](response_json) - self.signal_emitter.emit(result_text, usage_data) - - except HTTPError as e: - error_body = e.read().decode('utf-8') - try: - error_json = json.loads(error_body) - msg = error_json.get('error', {}).get('message', error_body) - except json.JSONDecodeError: - msg = error_body - self.signal_emitter.emit(f"

API Error ({e.code}): {msg}

", {}) - except URLError as e: - self.signal_emitter.emit(f"

Network Error: {e.reason}

", {}) - except Exception as e: - self.signal_emitter.emit(f"

An unexpected error occurred: {e}

", {}) +from calibre.utils.short_uuid import uuid4 class Action(NamedTuple): @@ -128,12 +65,12 @@ class Action(NamedTuple): @lru_cache(2) def default_actions() -> tuple[Action, ...]: return ( - Action('summarize', _('Summarize'), 'Provide a concise summary of the following text.'), - Action('explain', _('Explain'), 'Explain the following text in simple, easy-to-understand terms.'), - Action('points', _('Key points'), 'Extract the key points from the following text as a bulleted list.'), - Action('define', _('Define'), 'Identify and define any technical or complex terms in the following text.'), - Action('grammar', _('Correct grammar'), 'Correct any grammatical errors in the following text and provide the corrected version.'), - Action('english', _('As English'), 'Translate the following text into English.'), + Action('summarize', _('Summarize'), 'Provide a concise summary of the selected text.'), + Action('explain', _('Explain'), 'Explain the selected text in simple, easy-to-understand terms.'), + Action('points', _('Key points'), 'Extract the key points from the selected text as a bulleted list.'), + Action('define', _('Define'), 'Identify and define any technical or complex terms in the selected text.'), + Action('grammar', _('Correct grammar'), 'Correct any grammatical errors in the selected text and provide the corrected version.'), + Action('english', _('As English'), 'Translate the selected text into English.'), ) @@ -153,8 +90,10 @@ def current_actions(include_disabled=False): class ConversationHistory: def __init__(self, conversation_text: str = ''): + self.accumulator = StreamedResponseAccumulator() self.items: list[ChatMessage] = [] self.conversation_text: str = conversation_text + self.model_used = '' def __iter__(self) -> Iterator[ChatMessage]: return iter(self.items) @@ -170,6 +109,7 @@ class ConversationHistory: def copy(self, upto: int | None = None) -> 'ConversationHistory': ans = ConversationHistory(self.conversation_text) + ans.model_used = self.model_used if upto is None: ans.items = list(self.items) else: @@ -223,12 +163,13 @@ def format_llm_note(conversation: ConversationHistory) -> str: class LLMPanel(QWidget): - response_received = pyqtSignal(str, dict) + response_received = pyqtSignal(int, object) add_note_requested = pyqtSignal(dict) - _SAVE_ACTION_URL_SCHEME = 'calibre-llm-action' def __init__(self, parent=None, viewer=None, lookup_widget=None): super().__init__(parent) + self.save_note_hostname = f'{uuid4().lower()}.calibre' + self.configure_ai_hostname = f'{uuid4().lower()}.calibre' self.viewer = viewer self.counter = count(start=1) self.lookup_widget = lookup_widget @@ -262,7 +203,7 @@ class LLMPanel(QWidget): self.layout.addWidget(custom_prompt_group) self.result_display = QTextBrowser(self) - self.result_display.setOpenExternalLinks(False) + self.result_display.setOpenLinks(False) self.result_display.setMinimumHeight(150) self.result_display.anchorClicked.connect(self._on_chat_link_clicked) self.layout.addWidget(self.result_display) @@ -293,7 +234,7 @@ class LLMPanel(QWidget): self.custom_prompt_button.clicked.connect(self.run_custom_prompt) self.custom_prompt_edit.returnPressed.connect(self.run_custom_prompt) - self.response_received.connect(self.show_response) + self.response_received.connect(self.on_response_from_ai, type=Qt.ConnectionType.QueuedConnection) self.settings_button.clicked.connect(self.show_settings) self.show_initial_message() @@ -326,18 +267,21 @@ class LLMPanel(QWidget): dialog.actions_updated.connect(self.rebuild_actions_ui) dialog.exec() + @property + def ai_provider_plugin(self) -> AIProviderPlugin | None: + return plugin_for_purpose(AICapabilities.text_to_text) + @property def is_ready_for_use(self) -> bool: - p = plugin_for_purpose(AICapabilities.text_to_text) + p = self.ai_provider_plugin return p is not None and p.is_ready_for_use def show_initial_message(self): self.save_note_button.setEnabled(False) if not self.is_ready_for_use: - self.show_response('

' + _( - 'Please configure an AI provider by clicking the Settings button below.'), is_error_or_status=True) + self.show_html(f'

{_("First, configure an AI provider")}') else: - self.show_response(_('Select text in the book to begin.'), is_error_or_status=True) + self.show_html('

' + _('Select text in the book to begin.')) def update_with_text(self, text, highlight_data, is_read_only_view=False): new_uuid = highlight_data.get('uuid') if highlight_data else None @@ -365,9 +309,9 @@ class LLMPanel(QWidget): self.conversation_history = ConversationHistory() if text: - self.show_response(f"{_('Selected')}:
'{text[:200]}...'", is_error_or_status=True) + self.show_html(f"{_('Selected')}:
'{text[:200]}...'") else: - self.show_response(_('Ready. Ask a follow-up question.'), is_error_or_status=True) + self.show_html(_('Ready. Ask a follow-up question.')) if self.latched_highlight_uuid: self.save_note_button.setToolTip(_("Append this response to the existing highlight's note")) @@ -413,7 +357,7 @@ class LLMPanel(QWidget): else: bgcolor = assistant_bgcolor label = _('Assistant') - save_button_href = f'http://{self._SAVE_ACTION_URL_SCHEME}/save?index={i}' + save_button_href = f'http://{self.save_note_hostname}/{i}' html_output += f''' @@ -429,11 +373,11 @@ class LLMPanel(QWidget): def start_api_call(self, action_prompt): if not self.is_ready_for_use: - self.show_response(f"

{_('AI provider not configured')} {_( - 'Click the Settings button to configure an AI service provider.')}

", is_error_or_status=True) + self.show_error(f"{_('AI provider not configured.')}{_( + 'Configure AI provider')}", is_critical=False) return if not self.latched_conversation_text: - self.show_response(f"

{_('Error')}: {_('No text is selected for this conversation.')}

", is_error_or_status=True) + self.show_error(f"{_('Error')}: {_('No text is selected for this conversation.')}", is_critical=True) return if not self.conversation_history: @@ -446,16 +390,31 @@ class LLMPanel(QWidget): context_header += '.\n\n' context_header += f'I have selected the following text from this book:\n{self.latched_conversation_text}\n\n' self.conversation_history.append(ChatMessage( - id=next(self.counter), query=context_header, type=ChatMessage.system, extra_data=self.latched_conversation_text)) - self.conversation_history.append(ChatMessage( - id=next(self.counter), query=action_prompt, type=ChatMessageType.user)) - api_call_history = self.conversation_history.copy() + query=context_header, type=ChatMessage.system, extra_data=self.latched_conversation_text)) + self.conversation_history.append(ChatMessage(query=action_prompt, type=ChatMessageType.user)) self.result_display.setHtml(self._render_conversation_html(thinking=True)) self.result_display.verticalScrollBar().setValue(self.result_display.verticalScrollBar().maximum()) self.set_all_inputs_enabled(False) - api_call_thread = LLMAPICall(api_call_history, self.response_received) - api_call_thread.start() + self.current_api_call_number = next(self.counter) + api_call = Thread(name='LLMAPICall', daemon=True, target=self.do_api_call, args=( + self.conversation_history.copy(), self.current_api_call_number, self.ai_provider_plugin)) + api_call.start() + + def do_api_call( + self, conversation_history: ConversationHistory, current_api_call_number: int, ai_plugin: AIProviderPlugin + ) -> None: + for res in ai_plugin.text_chat(conversation_history.items, conversation_history.model_used): + self.response_received.emit(current_api_call_number, res) + self.response_received.emit(current_api_call_number, None) + + def on_response_from_ai(self, current_api_call_number: int, r: ChatResponse | None) -> None: + if current_api_call_number != self.current_api_call_number: + return + if r is None: + self.conversation_history.finalize_response() + else: + self.conversation_history.accumulator.accumulate(r) def show_response(self, response_text, usage_data=None, is_error_or_status=False): self.last_response_text = '' @@ -478,6 +437,16 @@ class LLMPanel(QWidget): self.set_all_inputs_enabled(True) self.custom_prompt_edit.clear() + def show_html(self, html: str) -> None: + self.save_note_button.setEnabled(bool(self.last_response_text) and bool(self.latched_conversation_text)) + self.result_display.setHtml(html) + self.result_display.verticalScrollBar().setValue(self.result_display.verticalScrollBar().maximum()) + self.set_all_inputs_enabled(True) + self.custom_prompt_edit.clear() + + def show_error(self, html: str, is_critical: bool = False) -> None: + self.show_html(f'

{html}') + def update_cost(self, usage_data): model_id = vprefs.get('llm_model_id', 'google/gemini-1.5-flash') prompt_tokens = usage_data.get('prompt_tokens', 0) @@ -511,19 +480,13 @@ class LLMPanel(QWidget): } self.add_note_requested.emit(payload) - def _on_chat_link_clicked(self, qurl): - url_str = qurl.toString() - parsed_url = urlparse(url_str) - if parsed_url.hostname == self._SAVE_ACTION_URL_SCHEME and parsed_url.path == '/save': - query_params = parse_qs(parsed_url.query) - index_str = query_params.get('index', [None])[0] - if index_str is not None: - try: - index = int(index_str) - self.save_specific_note(index) - except (ValueError, TypeError): - pass - return + def _on_chat_link_clicked(self, qurl: QUrl): + match qurl.host(): + case self.save_note_hostname: + index = int(qurl.path().strip('/')) + self.save_specific_note(index) + case self.configure_ai_hostname: + self.show_settings() def set_all_inputs_enabled(self, enabled): for i in range(self.quick_actions_layout.count()):