TTS config: Add a button to download/delete voices for the piper backend

This commit is contained in:
Kovid Goyal 2024-09-07 09:13:01 +05:30
parent ef104102ce
commit 5657351d46
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 124 additions and 9 deletions

View File

@ -1,8 +1,24 @@
#!/usr/bin/env python #!/usr/bin/env python
# License: GPLv3 Copyright: 2024, Kovid Goyal <kovid at kovidgoyal.net> # License: GPLv3 Copyright: 2024, Kovid Goyal <kovid at kovidgoyal.net>
from qt.core import (
from qt.core import QCheckBox, QFormLayout, QLabel, QLocale, QMediaDevices, QSize, QSlider, Qt, QTreeWidget, QTreeWidgetItem, QVBoxLayout, QWidget, pyqtSignal QCheckBox,
QFormLayout,
QHBoxLayout,
QIcon,
QLabel,
QLocale,
QMediaDevices,
QPushButton,
QSize,
QSlider,
Qt,
QTreeWidget,
QTreeWidgetItem,
QVBoxLayout,
QWidget,
pyqtSignal,
)
from calibre.gui2.tts.types import ( from calibre.gui2.tts.types import (
AudioDeviceId, AudioDeviceId,
@ -113,10 +129,13 @@ class Volume(QWidget):
class Voices(QTreeWidget): class Voices(QTreeWidget):
voice_changed = pyqtSignal()
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
self.setHeaderHidden(True) self.setHeaderHidden(True)
self.system_default_voice = Voice() self.system_default_voice = Voice()
self.currentItemChanged.connect(self.voice_changed)
def sizeHint(self) -> QSize: def sizeHint(self) -> QSize:
return QSize(400, 500) return QSize(400, 500)
@ -154,14 +173,23 @@ class Voices(QTreeWidget):
@property @property
def val(self) -> str: def val(self) -> str:
voice = self.currentItem().data(0, Qt.ItemDataRole.UserRole) voice = self.current_voice
return voice.name if voice else '' return voice.name if voice else ''
@property
def current_voice(self) -> Voice | None:
ci = self.currentItem()
if ci is not None:
return ci.data(0, Qt.ItemDataRole.UserRole)
class EngineSpecificConfig(QWidget): class EngineSpecificConfig(QWidget):
voice_changed = pyqtSignal()
def __init__(self, parent): def __init__(self, parent):
super().__init__(parent) super().__init__(parent)
self.engine_name = ''
self.l = l = QFormLayout(self) self.l = l = QFormLayout(self)
devs = QMediaDevices.audioOutputs() devs = QMediaDevices.audioOutputs()
dad = QMediaDevices.defaultAudioOutput() dad = QMediaDevices.defaultAudioOutput()
@ -183,6 +211,7 @@ class EngineSpecificConfig(QWidget):
self.audio_device = ad = QComboBox(self) self.audio_device = ad = QComboBox(self)
l.addRow(_('Output a&udio to:'), ad) l.addRow(_('Output a&udio to:'), ad)
self.voices = v = Voices(self) self.voices = v = Voices(self)
v.voice_changed.connect(self.voice_changed)
la = QLabel(_('V&oices:')) la = QLabel(_('V&oices:'))
la.setBuddy(v) la.setBuddy(v)
l.addRow(la) l.addRow(la)
@ -241,6 +270,7 @@ class EngineSpecificConfig(QWidget):
else: else:
self.layout().setRowVisible(self.audio_device, False) self.layout().setRowVisible(self.audio_device, False)
self.rebuild_voices() self.rebuild_voices()
return metadata
def rebuild_voices(self): def rebuild_voices(self):
try: try:
@ -269,6 +299,29 @@ class EngineSpecificConfig(QWidget):
break break
return ans return ans
def voice_action(self):
v = self.voices.current_voice
if v is None:
return
metadata = available_engines()[self.engine_name]
if not metadata.has_managed_voices:
return
tts = create_tts_backend(self.engine_name)
if tts.is_voice_downloaded(v):
tts.delete_voice(v)
else:
tts.download_voice(v)
def current_voice_is_downloaded(self) -> bool:
v = self.voices.current_voice
if v is None:
return False
metadata = available_engines()[self.engine_name]
if not metadata.has_managed_voices:
return False
tts = create_tts_backend(self.engine_name)
return tts.is_voice_downloaded(v)
class ConfigDialog(Dialog): class ConfigDialog(Dialog):
@ -279,12 +332,35 @@ class ConfigDialog(Dialog):
self.l = l = QVBoxLayout(self) self.l = l = QVBoxLayout(self)
self.engine_choice = ec = EngineChoice(self) self.engine_choice = ec = EngineChoice(self)
self.engine_specific_config = esc = EngineSpecificConfig(self) self.engine_specific_config = esc = EngineSpecificConfig(self)
ec.changed.connect(esc.set_engine) ec.changed.connect(self.set_engine)
esc.voice_changed.connect(self.update_voice_button)
l.addWidget(ec) l.addWidget(ec)
l.addWidget(esc) l.addWidget(esc)
l.addWidget(self.bb) self.voice_button = b = QPushButton(self)
b.clicked.connect(self.voice_action)
h = QHBoxLayout()
l.addLayout(h)
h.addWidget(b), h.addStretch(10), h.addWidget(self.bb)
self.initial_engine_choice = ec.value self.initial_engine_choice = ec.value
esc.set_engine(self.initial_engine_choice) self.set_engine(self.initial_engine_choice)
def set_engine(self, engine_name: str) -> None:
metadata = self.engine_specific_config.set_engine(engine_name)
self.voice_button.setVisible(metadata.has_managed_voices)
self.update_voice_button()
def update_voice_button(self):
b = self.voice_button
if self.engine_specific_config.current_voice_is_downloaded():
b.setIcon(QIcon.ic('trash.png'))
b.setText(_('Remove downloaded voice'))
else:
b.setIcon(QIcon.ic('download-metadata.png'))
b.setText(_('Download voice'))
def voice_action(self):
self.engine_specific_config.voice_action()
self.update_voice_button()
@property @property
def engine_changed(self) -> bool: def engine_changed(self) -> bool:

View File

@ -7,6 +7,7 @@ import os
import re import re
import sys import sys
from collections import deque from collections import deque
from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from itertools import count from itertools import count
from time import monotonic from time import monotonic
@ -441,12 +442,32 @@ class Piper(TTSBackend):
lang = canonicalize_lang(lang) or lang lang = canonicalize_lang(lang) or lang
return self._voice_for_lang.get(lang) or self._voice_for_lang['eng'] return self._voice_for_lang.get(lang) or self._voice_for_lang['eng']
def _ensure_voice_is_downloaded(self, voice: Voice) -> tuple[str, str]: def _paths_for_voice(self, voice: Voice) -> tuple[str, str]:
fname = voice.engine_data['model_filename'] fname = voice.engine_data['model_filename']
model_path = os.path.join(cache_dir(), 'piper-voices', fname) model_path = os.path.join(cache_dir(), 'piper-voices', fname)
config_path = os.path.join(os.path.dirname(model_path), fname + '.json') config_path = os.path.join(os.path.dirname(model_path), fname + '.json')
return model_path, config_path
def is_voice_downloaded(self, v: Voice) -> bool:
if not v.name:
v = self._default_voice
for path in self._paths_for_voice(v):
if not os.path.exists(path):
return False
return True
def delete_voice(self, v: Voice) -> None:
if not v.name:
v = self._default_voice
for path in self._paths_for_voice(v):
with suppress(FileNotFoundError):
os.remove(path)
def _download_voice(self, voice: Voice, download_even_if_exists: bool = False) -> tuple[str, str]:
model_path, config_path = self._paths_for_voice(voice)
if os.path.exists(model_path) and os.path.exists(config_path): if os.path.exists(model_path) and os.path.exists(config_path):
return model_path, config_path if not download_even_if_exists:
return model_path, config_path
os.makedirs(os.path.dirname(model_path), exist_ok=True) os.makedirs(os.path.dirname(model_path), exist_ok=True)
from calibre.gui2.tts.download import DownloadResources from calibre.gui2.tts.download import DownloadResources
d = DownloadResources(_('Downloading voice for Read aloud'), _('Downloading neural network for the {} voice').format(voice.human_name), { d = DownloadResources(_('Downloading voice for Read aloud'), _('Downloading neural network for the {} voice').format(voice.human_name), {
@ -457,6 +478,14 @@ class Piper(TTSBackend):
return model_path, config_path return model_path, config_path
return '', '' return '', ''
def download_voice(self, v: Voice) -> None:
if not v.name:
v = self._default_voice
self._download_voice(v, download_even_if_exists=True)
def _ensure_voice_is_downloaded(self, voice: Voice) -> tuple[str, str]:
return self._download_voice(voice)
def validate_settings(self, s: EngineSpecificSettings, parent: QWidget | None) -> bool: def validate_settings(self, s: EngineSpecificSettings, parent: QWidget | None) -> bool:
self._load_voice_metadata() self._load_voice_metadata()
voice = self._voice_name_map.get(s.voice_name) or self._default_voice voice = self._voice_name_map.get(s.voice_name) or self._default_voice

View File

@ -38,6 +38,7 @@ class EngineMetadata(NamedTuple):
can_change_pitch: bool = True can_change_pitch: bool = True
can_change_volume: bool = True can_change_volume: bool = True
voices_have_quality_metadata: bool = False voices_have_quality_metadata: bool = False
has_managed_voices: bool = False
class Quality(Enum): class Quality(Enum):
@ -218,7 +219,7 @@ def available_engines() -> dict[str, EngineMetadata]:
ans['piper'] = EngineMetadata('piper', _('The Piper Neural Speech Engine'), _( ans['piper'] = EngineMetadata('piper', _('The Piper Neural Speech Engine'), _(
'The "piper" engine can track the currently spoken sentence on screen. It uses a neural network ' 'The "piper" engine can track the currently spoken sentence on screen. It uses a neural network '
'for natural sounding voices. The neural network is run locally on your computer, it is fairly resource intensive to run.' 'for natural sounding voices. The neural network is run locally on your computer, it is fairly resource intensive to run.'
), TrackingCapability.Sentence, can_change_pitch=False, voices_have_quality_metadata=True) ), TrackingCapability.Sentence, can_change_pitch=False, voices_have_quality_metadata=True, has_managed_voices=True)
if islinux: if islinux:
from speechd.paths import SPD_SPAWN_CMD from speechd.paths import SPD_SPAWN_CMD
cmd = os.getenv("SPEECHD_CMD", SPD_SPAWN_CMD) cmd = os.getenv("SPEECHD_CMD", SPD_SPAWN_CMD)
@ -281,6 +282,15 @@ class TTSBackend(QObject):
def validate_settings(self, s: EngineSpecificSettings, parent: QWidget | None) -> bool: def validate_settings(self, s: EngineSpecificSettings, parent: QWidget | None) -> bool:
return True return True
def is_voice_downloaded(self, v: Voice) -> bool:
return True
def delete_voice(self, v: Voice) -> None:
pass
def download_voice(self, v: Voice) -> None:
pass
engine_instances: dict[str, TTSBackend] = {} engine_instances: dict[str, TTSBackend] = {}