diff --git a/src/calibre/ai/open_router/backend.py b/src/calibre/ai/open_router/backend.py
index 8632cbdbdf..347a038150 100644
--- a/src/calibre/ai/open_router/backend.py
+++ b/src/calibre/ai/open_router/backend.py
@@ -2,7 +2,6 @@
 # License: GPLv3 Copyright: 2025, Kovid Goyal
 
 import datetime
-import http
 import json
 import os
 import re
@@ -11,13 +10,12 @@ from collections.abc import Iterable, Iterator
 from functools import lru_cache
 from pprint import pprint
 from typing import Any, NamedTuple
-from urllib.error import HTTPError, URLError
 from urllib.request import Request
 
 from calibre.ai import AICapabilities, ChatMessage, ChatMessageType, ChatResponse, NoFreeModels
 from calibre.ai.open_router import OpenRouterAI
 from calibre.ai.prefs import pref_for_provider
-from calibre.ai.utils import StreamedResponseAccumulator, get_cached_resource, opener
+from calibre.ai.utils import StreamedResponseAccumulator, chat_with_error_handler, get_cached_resource, read_streaming_response
 from calibre.constants import cache_dir
 from polyglot.binary import from_hex_unicode
 
@@ -246,13 +244,7 @@ def text_chat_implementation(messages: Iterable[ChatMessage], use_model: str = '
         data['reasoning']['enabled'] = False
     rq = chat_request(data)
 
-    def read_response(buffer: str) -> Iterator[ChatResponse]:
-        if not buffer.startswith('data: '):
-            return
-        buffer = buffer[6:].rstrip()
-        if buffer == '[DONE]':
-            return
-        data = json.loads(buffer)
+    for data in read_streaming_response(rq):
         for choice in data['choices']:
             d = choice['delta']
             c = d.get('content') or ''
@@ -267,40 +259,9 @@ def text_chat_implementation(messages: Iterable[ChatMessage], use_model: str = '
                 model=data.get('model') or '',
                 has_metadata=True,
             )
-    with opener().open(rq) as response:
-        if response.status != http.HTTPStatus.OK:
-            raise Exception(f'OpenRouter API failed with status code: {response.status} and body: {response.read().decode("utf-8", "replace")}')
-        buffer = ''
-        for raw_line in response:
-            line = raw_line.decode('utf-8')
-            if line.strip() == '':
-                if buffer:
-                    yield from read_response(buffer)
-                    buffer = ''
-            else:
-                buffer += line
-        yield from read_response(buffer)
-
 
 def text_chat(messages: Iterable[ChatMessage], use_model: str = '') -> Iterator[ChatResponse]:
-    try:
-        yield from text_chat_implementation(messages, use_model)
-    except HTTPError as e:
-        try:
-            details = e.fp.read().decode()
-        except Exception:
-            details = ''
-        try:
-            error_json = json.loads(details)
-            details = error_json.get('error', {}).get('message', details)
-        except Exception:
-            pass
-        yield ChatResponse(exception=e, error_details=details)
-    except URLError as e:
-        yield ChatResponse(exception=e, error_details=f'Network error: {e.reason}')
-    except Exception as e:
-        import traceback
-        yield ChatResponse(exception=e, error_details=traceback.format_exc())
+    yield from chat_with_error_handler(text_chat_implementation(messages, use_model))
 
 
 def develop(use_model: str = ''):
diff --git a/src/calibre/ai/utils.py b/src/calibre/ai/utils.py
index 83a260f9fa..35003ffa77 100644
--- a/src/calibre/ai/utils.py
+++ b/src/calibre/ai/utils.py
@@ -2,13 +2,16 @@
 # License: GPLv3 Copyright: 2025, Kovid Goyal
 
 import datetime
+import http
+import json
 import os
 import tempfile
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from contextlib import suppress
 from threading import Thread
 from typing import Any
-from urllib.request import ProxyHandler, build_opener
+from urllib.error import HTTPError, URLError
+from urllib.request import ProxyHandler, Request, build_opener
 
 from calibre import get_proxies
 from calibre.ai import ChatMessage, ChatMessageType, ChatResponse
@@ -63,6 +66,55 @@ def get_cached_resource(path: str, url: str) -> bytes:
     return data
 
 
+def _read_response(buffer: str) -> Iterator[dict[str, Any]]:
+    if not buffer.startswith('data: '):
+        return
+    buffer = buffer[6:].rstrip()
+    if buffer == '[DONE]':
+        return
+    yield json.loads(buffer)
+
+
+def read_streaming_response(rq: Request) -> Iterator[dict[str, Any]]:
+    with opener().open(rq) as response:
+        if response.status != http.HTTPStatus.OK:
+            details = ''
+            with suppress(Exception):
+                details = response.read().decode('utf-8', 'replace')
+            raise Exception(f'Reading from AI provider failed with HTTP response status: {response.status} and body: {details}')
+        buffer = ''
+        for raw_line in response:
+            line = raw_line.decode('utf-8')
+            if line.strip() == '':
+                if buffer:
+                    yield from _read_response(buffer)
+                    buffer = ''
+            else:
+                buffer += line
+        yield from _read_response(buffer)
+
+
+def chat_with_error_handler(it: Iterable[ChatResponse]) -> Iterator[ChatResponse]:
+    try:
+        yield from it
+    except HTTPError as e:
+        try:
+            details = e.fp.read().decode('utf-8', 'replace')
+        except Exception:
+            details = ''
+        try:
+            error_json = json.loads(details)
+            details = error_json.get('error', {}).get('message', details)
+        except Exception:
+            pass
+        yield ChatResponse(exception=e, error_details=details)
+    except URLError as e:
+        yield ChatResponse(exception=e, error_details=f'Network error: {e.reason}')
+    except Exception as e:
+        import traceback
+        yield ChatResponse(exception=e, error_details=traceback.format_exc())
+
+
 class StreamedResponseAccumulator:
 
     def __init__(self):
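
For reference, a minimal sketch (not part of the commit) of how a provider backend is expected to consume the two helpers this change adds to calibre.ai.utils. The demo_* names are invented for illustration, and the content= keyword on ChatResponse is an assumption based on the delta handling shown in the OpenRouter backend above; the helper signatures themselves are taken directly from the diff.

from collections.abc import Iterator
from urllib.request import Request

from calibre.ai import ChatResponse
from calibre.ai.utils import chat_with_error_handler, read_streaming_response


def demo_chat_implementation(rq: Request) -> Iterator[ChatResponse]:
    # read_streaming_response() does the SSE framing: it strips the 'data: '
    # prefix, stops at the '[DONE]' sentinel, raises on non-200 responses and
    # yields each event as a parsed JSON dict.
    for data in read_streaming_response(rq):
        for choice in data.get('choices', ()):
            text = choice.get('delta', {}).get('content') or ''
            if text:
                # Assumption: ChatResponse accepts a 'content' keyword.
                yield ChatResponse(content=text)


def demo_chat(rq: Request) -> Iterator[ChatResponse]:
    # chat_with_error_handler() turns HTTPError, URLError and any other
    # exception raised while streaming into a final ChatResponse carrying the
    # exception and human-readable error_details, instead of propagating it.
    yield from chat_with_error_handler(demo_chat_implementation(rq))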