TTS config: Add a button to download/delete voices for the piper backend

This commit is contained in:
Kovid Goyal 2024-09-07 09:13:01 +05:30
parent ef104102ce
commit 5657351d46
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 124 additions and 9 deletions

View File

@ -1,8 +1,24 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2024, Kovid Goyal <kovid at kovidgoyal.net>
from qt.core import QCheckBox, QFormLayout, QLabel, QLocale, QMediaDevices, QSize, QSlider, Qt, QTreeWidget, QTreeWidgetItem, QVBoxLayout, QWidget, pyqtSignal
from qt.core import (
QCheckBox,
QFormLayout,
QHBoxLayout,
QIcon,
QLabel,
QLocale,
QMediaDevices,
QPushButton,
QSize,
QSlider,
Qt,
QTreeWidget,
QTreeWidgetItem,
QVBoxLayout,
QWidget,
pyqtSignal,
)
from calibre.gui2.tts.types import (
AudioDeviceId,
@ -113,10 +129,13 @@ class Volume(QWidget):
class Voices(QTreeWidget):
voice_changed = pyqtSignal()
def __init__(self, parent=None):
super().__init__(parent)
self.setHeaderHidden(True)
self.system_default_voice = Voice()
self.currentItemChanged.connect(self.voice_changed)
def sizeHint(self) -> QSize:
return QSize(400, 500)
@ -154,14 +173,23 @@ class Voices(QTreeWidget):
@property
def val(self) -> str:
voice = self.currentItem().data(0, Qt.ItemDataRole.UserRole)
voice = self.current_voice
return voice.name if voice else ''
@property
def current_voice(self) -> Voice | None:
ci = self.currentItem()
if ci is not None:
return ci.data(0, Qt.ItemDataRole.UserRole)
class EngineSpecificConfig(QWidget):
voice_changed = pyqtSignal()
def __init__(self, parent):
super().__init__(parent)
self.engine_name = ''
self.l = l = QFormLayout(self)
devs = QMediaDevices.audioOutputs()
dad = QMediaDevices.defaultAudioOutput()
@ -183,6 +211,7 @@ class EngineSpecificConfig(QWidget):
self.audio_device = ad = QComboBox(self)
l.addRow(_('Output a&udio to:'), ad)
self.voices = v = Voices(self)
v.voice_changed.connect(self.voice_changed)
la = QLabel(_('V&oices:'))
la.setBuddy(v)
l.addRow(la)
@ -241,6 +270,7 @@ class EngineSpecificConfig(QWidget):
else:
self.layout().setRowVisible(self.audio_device, False)
self.rebuild_voices()
return metadata
def rebuild_voices(self):
try:
@ -269,6 +299,29 @@ class EngineSpecificConfig(QWidget):
break
return ans
def voice_action(self):
v = self.voices.current_voice
if v is None:
return
metadata = available_engines()[self.engine_name]
if not metadata.has_managed_voices:
return
tts = create_tts_backend(self.engine_name)
if tts.is_voice_downloaded(v):
tts.delete_voice(v)
else:
tts.download_voice(v)
def current_voice_is_downloaded(self) -> bool:
v = self.voices.current_voice
if v is None:
return False
metadata = available_engines()[self.engine_name]
if not metadata.has_managed_voices:
return False
tts = create_tts_backend(self.engine_name)
return tts.is_voice_downloaded(v)
class ConfigDialog(Dialog):
@ -279,12 +332,35 @@ class ConfigDialog(Dialog):
self.l = l = QVBoxLayout(self)
self.engine_choice = ec = EngineChoice(self)
self.engine_specific_config = esc = EngineSpecificConfig(self)
ec.changed.connect(esc.set_engine)
ec.changed.connect(self.set_engine)
esc.voice_changed.connect(self.update_voice_button)
l.addWidget(ec)
l.addWidget(esc)
l.addWidget(self.bb)
self.voice_button = b = QPushButton(self)
b.clicked.connect(self.voice_action)
h = QHBoxLayout()
l.addLayout(h)
h.addWidget(b), h.addStretch(10), h.addWidget(self.bb)
self.initial_engine_choice = ec.value
esc.set_engine(self.initial_engine_choice)
self.set_engine(self.initial_engine_choice)
def set_engine(self, engine_name: str) -> None:
metadata = self.engine_specific_config.set_engine(engine_name)
self.voice_button.setVisible(metadata.has_managed_voices)
self.update_voice_button()
def update_voice_button(self):
b = self.voice_button
if self.engine_specific_config.current_voice_is_downloaded():
b.setIcon(QIcon.ic('trash.png'))
b.setText(_('Remove downloaded voice'))
else:
b.setIcon(QIcon.ic('download-metadata.png'))
b.setText(_('Download voice'))
def voice_action(self):
self.engine_specific_config.voice_action()
self.update_voice_button()
@property
def engine_changed(self) -> bool:

View File

@ -7,6 +7,7 @@ import os
import re
import sys
from collections import deque
from contextlib import suppress
from dataclasses import dataclass
from itertools import count
from time import monotonic
@ -441,12 +442,32 @@ class Piper(TTSBackend):
lang = canonicalize_lang(lang) or lang
return self._voice_for_lang.get(lang) or self._voice_for_lang['eng']
def _ensure_voice_is_downloaded(self, voice: Voice) -> tuple[str, str]:
def _paths_for_voice(self, voice: Voice) -> tuple[str, str]:
fname = voice.engine_data['model_filename']
model_path = os.path.join(cache_dir(), 'piper-voices', fname)
config_path = os.path.join(os.path.dirname(model_path), fname + '.json')
return model_path, config_path
def is_voice_downloaded(self, v: Voice) -> bool:
if not v.name:
v = self._default_voice
for path in self._paths_for_voice(v):
if not os.path.exists(path):
return False
return True
def delete_voice(self, v: Voice) -> None:
if not v.name:
v = self._default_voice
for path in self._paths_for_voice(v):
with suppress(FileNotFoundError):
os.remove(path)
def _download_voice(self, voice: Voice, download_even_if_exists: bool = False) -> tuple[str, str]:
model_path, config_path = self._paths_for_voice(voice)
if os.path.exists(model_path) and os.path.exists(config_path):
return model_path, config_path
if not download_even_if_exists:
return model_path, config_path
os.makedirs(os.path.dirname(model_path), exist_ok=True)
from calibre.gui2.tts.download import DownloadResources
d = DownloadResources(_('Downloading voice for Read aloud'), _('Downloading neural network for the {} voice').format(voice.human_name), {
@ -457,6 +478,14 @@ class Piper(TTSBackend):
return model_path, config_path
return '', ''
def download_voice(self, v: Voice) -> None:
if not v.name:
v = self._default_voice
self._download_voice(v, download_even_if_exists=True)
def _ensure_voice_is_downloaded(self, voice: Voice) -> tuple[str, str]:
return self._download_voice(voice)
def validate_settings(self, s: EngineSpecificSettings, parent: QWidget | None) -> bool:
self._load_voice_metadata()
voice = self._voice_name_map.get(s.voice_name) or self._default_voice

View File

@ -38,6 +38,7 @@ class EngineMetadata(NamedTuple):
can_change_pitch: bool = True
can_change_volume: bool = True
voices_have_quality_metadata: bool = False
has_managed_voices: bool = False
class Quality(Enum):
@ -218,7 +219,7 @@ def available_engines() -> dict[str, EngineMetadata]:
ans['piper'] = EngineMetadata('piper', _('The Piper Neural Speech Engine'), _(
'The "piper" engine can track the currently spoken sentence on screen. It uses a neural network '
'for natural sounding voices. The neural network is run locally on your computer, it is fairly resource intensive to run.'
), TrackingCapability.Sentence, can_change_pitch=False, voices_have_quality_metadata=True)
), TrackingCapability.Sentence, can_change_pitch=False, voices_have_quality_metadata=True, has_managed_voices=True)
if islinux:
from speechd.paths import SPD_SPAWN_CMD
cmd = os.getenv("SPEECHD_CMD", SPD_SPAWN_CMD)
@ -281,6 +282,15 @@ class TTSBackend(QObject):
def validate_settings(self, s: EngineSpecificSettings, parent: QWidget | None) -> bool:
return True
def is_voice_downloaded(self, v: Voice) -> bool:
return True
def delete_voice(self, v: Voice) -> None:
pass
def download_voice(self, v: Voice) -> None:
pass
engine_instances: dict[str, TTSBackend] = {}