More work on TTS embedding

This commit is contained in:
Kovid Goyal 2024-10-09 09:45:43 +05:30
parent 72d958271e
commit 55e6ef52ad
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 258 additions and 96 deletions

View File

@ -171,9 +171,9 @@ def download_resources(
from calibre import browser from calibre import browser
print(title) print(title)
print(message) print(message)
br = browser()
for url, (path, name) in urls.items(): for url, (path, name) in urls.items():
print(_('Downloading {}...').format(name)) print(_('Downloading {}...').format(name))
br = browser()
data = br.open_novisit(url).read() data = br.open_novisit(url).read()
with open(path, 'wb') as f: with open(path, 'wb') as f:
f.write(data) f.write(data)

View File

@ -9,11 +9,14 @@ import sys
from collections import deque from collections import deque
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache
from itertools import count from itertools import count
from queue import Empty, Queue
from threading import Event
from time import monotonic from time import monotonic
from typing import BinaryIO, Iterable, Iterator
from qt.core import ( from qt.core import (
QApplication,
QAudio, QAudio,
QAudioFormat, QAudioFormat,
QAudioSink, QAudioSink,
@ -32,7 +35,7 @@ from qt.core import (
from calibre.constants import cache_dir, is_debugging, piper_cmdline from calibre.constants import cache_dir, is_debugging, piper_cmdline
from calibre.gui2 import error_dialog from calibre.gui2 import error_dialog
from calibre.gui2.tts.types import EngineSpecificSettings, Quality, TTSBackend, Voice, widget_parent from calibre.gui2.tts.types import TTS_EMBEDED_CONFIG, EngineSpecificSettings, Quality, TTSBackend, Voice, widget_parent
from calibre.spell.break_iterator import PARAGRAPH_SEPARATOR, split_into_sentences_for_tts from calibre.spell.break_iterator import PARAGRAPH_SEPARATOR, split_into_sentences_for_tts
from calibre.utils.localization import canonicalize_lang, get_lang from calibre.utils.localization import canonicalize_lang, get_lang
from calibre.utils.resources import get_path as P from calibre.utils.resources import get_path as P
@ -46,6 +49,97 @@ def debug(*a, **kw):
print(f'[{monotonic() - debug.first:.2f}]', *a, **kw) print(f'[{monotonic() - debug.first:.2f}]', *a, **kw)
def audio_format(audio_rate: int = 22050) -> QAudioFormat:
    '''
    Build the QAudioFormat describing raw piper audio.

    The returned format uses Int16 samples, a mono channel configuration and
    a sample rate of ``audio_rate``.
    '''
    result = QAudioFormat()
    # The three setters are independent of each other, so their order does not matter
    result.setChannelConfig(QAudioFormat.ChannelConfig.ChannelConfigMono)
    result.setSampleRate(audio_rate)
    result.setSampleFormat(QAudioFormat.SampleFormat.Int16)
    return result
def piper_process_metadata(model_path, config_path, s: EngineSpecificSettings, voice: Voice) -> tuple[int, list[str]]:
    '''
    Work out how to launch piper for ``voice`` with the settings ``s``.

    The voice's JSON config is loaded from ``config_path`` once and cached in
    ``voice.engine_data['metadata']``.

    :param model_path: Path to the voice model; an empty value means the download failed.
    :param config_path: Path to the JSON config that accompanies the model.
    :param s: Settings providing ``rate`` and ``sentence_delay``.
    :param voice: The voice whose ``engine_data`` caches the parsed config.
    :return: ``(audio_rate, cmdline)`` where ``audio_rate`` is the sample rate read
        from the config and ``cmdline`` is the full piper command line.
    :raises Exception: if ``model_path`` is empty.
    '''
    if not model_path:
        raise Exception('Could not download voice data')
    if 'metadata' not in voice.engine_data:
        # Read as bytes so that json detects the encoding itself, instead of
        # decoding with the locale's default text encoding which is not
        # necessarily UTF-8.
        with open(config_path, 'rb') as f:
            voice.engine_data['metadata'] = json.load(f)
    audio_rate = voice.engine_data['metadata']['audio']['sample_rate']
    # A rate of -1 gives a length scale of 2, a rate of 1 gives 0.1 (clamped from 0)
    length_scale = max(0.1, 1 + -1 * s.rate)
    cmdline = list(piper_cmdline())
    cmdline += [
        '--model', model_path, '--config', config_path, '--output-raw', '--json-input',
        '--sentence-silence', str(s.sentence_delay), '--length_scale', str(length_scale),
    ]
    if is_debugging():
        cmdline.append('--debug')
    return audio_rate, cmdline
def piper_cache_dir() -> str:
    '''Return the ``piper-voices`` directory located inside calibre's cache directory.'''
    base = cache_dir()
    return os.path.join(base, 'piper-voices')
def paths_for_voice(voice: Voice) -> tuple[str, str]:
    '''
    Return ``(model_path, config_path)`` for ``voice``.

    The model lives in the piper cache directory under the voice's
    ``model_filename``; the config sits in the same directory and has the
    same file name with ``.json`` appended.
    '''
    model_filename = voice.engine_data['model_filename']
    model_path = os.path.join(piper_cache_dir(), model_filename)
    voice_dir = os.path.dirname(model_path)
    return model_path, os.path.join(voice_dir, f'{model_filename}.json')
def load_voice_metadata() -> tuple[dict[str, Voice], tuple[Voice, ...], dict[str, Voice]]:
    '''
    Parse the bundled ``piper-voices.json`` resource into Voice objects.

    For every voice only one quality variant is kept: the one whose Quality
    has the smallest ``value``. Each voice records its download URLs, its
    model file name and whether that file is already present in the piper
    cache directory.

    :return: ``(voice_name_map, voices, voice_for_lang)``: voices keyed by
        name, all voices as a tuple, and the default voice per language.
        The default is the first voice after sorting by quality value,
        except for ``eng`` where the ``libritts`` voice is preferred if present.
    '''
    catalog = json.loads(P('piper-voices.json', data=True))
    # A missing or unreadable cache directory simply means nothing is downloaded yet
    try:
        downloaded = set(os.listdir(piper_cache_dir()))
    except OSError:
        downloaded = set()

    all_voices = []
    by_lang = {}
    by_name = {}
    for bcp_code, voice_map in catalog['lang_map'].items():
        lang, _sep, country = bcp_code.partition('_')
        lang = canonicalize_lang(lang) or lang
        voices_for_lang = by_lang.setdefault(lang, [])
        for voice_name, qual_map in voice_map.items():
            best_qual = voice = None
            for qual, entry in qual_map.items():
                q = Quality.from_piper_quality(qual)
                if best_qual is not None and not (q.value < best_qual.value):
                    continue
                best_qual = q
                model_filename = f'{bcp_code}-{voice_name}-{qual}.onnx'
                engine_data = {
                    'model_url': entry['model'], 'config_url': entry['config'],
                    'model_filename': model_filename, 'is_downloaded': model_filename in downloaded,
                }
                voice = Voice(bcp_code + ':' + voice_name, lang, country, human_name=voice_name, quality=q, engine_data=engine_data)
            if voice:
                all_voices.append(voice)
                by_name[voice.name] = voice
                voices_for_lang.append(voice)

    default_for_lang = {}
    for lang, candidates in by_lang.items():
        candidates.sort(key=lambda v: v.quality.value)
        chosen = candidates[0]
        if lang == 'eng':
            chosen = next((v for v in candidates if v.human_name == 'libritts'), chosen)
        default_for_lang[lang] = chosen
    return by_name, tuple(all_voices), default_for_lang
def download_voice(voice: Voice, download_even_if_exists: bool = False, parent: QObject | None = None, headless: bool = False) -> tuple[str, str]:
    '''
    Make sure the model and config files for ``voice`` exist on disk.

    If both files already exist they are returned as is, unless
    ``download_even_if_exists`` is set. Otherwise both are fetched with
    ``download_resources`` and ``voice.engine_data['is_downloaded']`` is
    updated with the outcome.

    :return: ``(model_path, config_path)`` on success, ``('', '')`` if the download failed.
    '''
    paths = model_path, config_path = paths_for_voice(voice)
    already_present = all(os.path.exists(p) for p in paths)
    if already_present and not download_even_if_exists:
        return model_path, config_path

    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    from calibre.gui2.tts.download import download_resources
    title = _('Downloading voice for Read aloud')
    message = _('Downloading neural network for the {} voice').format(voice.human_name)
    # Maps each URL to its destination path and a description of what is being fetched
    urls = {
        voice.engine_data['model_url']: (model_path, _('Neural network data')),
        voice.engine_data['config_url']: (config_path, _('Neural network metadata')),
    }
    ok = download_resources(title, message, urls, parent=widget_parent(parent), headless=headless)
    voice.engine_data['is_downloaded'] = bool(ok)
    if not ok:
        return '', ''
    return model_path, config_path
@dataclass @dataclass
class Utterance: class Utterance:
id: int id: int
@ -156,13 +250,34 @@ def split_into_utterances(text: str, counter: count, lang: str = 'en'):
yield u yield u
@lru_cache(2)
def stderr_pat():
    '''Return the compiled pattern matching ``[piper] [<kind>] <message>`` in a line of piper's stderr.'''
    return re.compile(rb'\[piper\] \[([a-zA-Z0-9_]+?)\] (.+)')


def detect_end_of_data(data: bytes, callback):
    '''
    Scan piper's stderr output for end-of-synthesis and error messages.

    Only complete lines (those terminated by a newline) are examined. The
    trailing partial line is returned so that the caller can prepend it to
    the next chunk of stderr data.

    :param data: Pending stderr bytes, possibly ending in an incomplete line.
    :param callback: Called as ``callback(True, None)`` when an info line starting
        with ``Real-time factor:`` is seen, and as ``callback(False, payload)``
        for every error line. ``payload`` is the raw message as bytes, since the
        callbacks decode it themselves.
    :return: The bytes after the last newline in ``data``.
    '''
    pat = stderr_pat()
    lines = data.split(b'\n')
    for line in lines[:-1]:
        if m := pat.search(line):
            which, payload = m.group(1), m.group(2)
            if which == b'info':
                debug(f'[piper-info] {payload.decode("utf-8", "replace")}')
                if payload.startswith(b'Real-time factor:'):
                    callback(True, None)
            elif which == b'error':
                # Pass bytes: the callbacks call payload.decode(), which would
                # fail with AttributeError on an already decoded str
                callback(False, payload)
            elif which == b'debug':
                debug(f'[piper-debug] {payload.decode("utf-8", "replace")}')
    return lines[-1]
class Piper(TTSBackend): class Piper(TTSBackend):
engine_name: str = 'piper' engine_name: str = 'piper'
filler_char: str = PARAGRAPH_SEPARATOR filler_char: str = PARAGRAPH_SEPARATOR
_synthesis_done = pyqtSignal() _synthesis_done = pyqtSignal()
def __init__(self, engine_name: str = '', parent: QObject|None = None): def __init__(self, engine_name: str = '', parent: QObject | None = None):
super().__init__(parent) super().__init__(parent)
self._process: QProcess | None = None self._process: QProcess | None = None
self._audio_sink: QAudioSink | None = None self._audio_sink: QAudioSink | None = None
@ -179,7 +294,6 @@ class Piper(TTSBackend):
self._errors_from_piper: list[str] = [] self._errors_from_piper: list[str] = []
self._pending_stderr_data = b'' self._pending_stderr_data = b''
self._stderr_pat = re.compile(rb'\[piper\] \[([a-zA-Z0-9_]+?)\] (.+)')
self._synthesis_done.connect(self._utterance_synthesized, type=Qt.ConnectionType.QueuedConnection) self._synthesis_done.connect(self._utterance_synthesized, type=Qt.ConnectionType.QueuedConnection)
atexit.register(self.shutdown) atexit.register(self.shutdown)
@ -188,16 +302,21 @@ class Piper(TTSBackend):
self._load_voice_metadata() self._load_voice_metadata()
return {'': self._voices} return {'': self._voices}
def say(self, text: str) -> None: def _wait_for_process_to_start(self) -> bool:
if self._last_error:
return
self.stop()
if not self.process.waitForStarted(): if not self.process.waitForStarted():
cmdline = [self.process.program()] + self.process.arguments() cmdline = [self.process.program()] + self.process.arguments()
if self.process.error() is QProcess.ProcessError.TimedOut: if self.process.error() is QProcess.ProcessError.TimedOut:
self._set_error(f'Timed out waiting for piper process {cmdline} to start') self._set_error(f'Timed out waiting for piper process {cmdline} to start')
else: else:
self._set_error(f'Failed to start piper process: {cmdline}') self._set_error(f'Failed to start piper process: {cmdline}')
return False
return True
def say(self, text: str) -> None:
if self._last_error:
return
self.stop()
if not self._wait_for_process_to_start():
return return
lang = 'en' lang = 'en'
if self._current_voice and self._current_voice.language_code: if self._current_voice and self._current_voice.language_code:
@ -265,17 +384,10 @@ class Piper(TTSBackend):
try: try:
self._load_voice_metadata() self._load_voice_metadata()
s = EngineSpecificSettings.create_from_config(self.engine_name) s = EngineSpecificSettings.create_from_config(self.engine_name)
length_scale = max(0.1, 1 + -1 * s.rate) # maps -1 to 1 to 2 to 0.1
voice = self._voice_name_map.get(s.voice_name) or self._default_voice voice = self._voice_name_map.get(s.voice_name) or self._default_voice
model_path, config_path = self._ensure_voice_is_downloaded(voice) model_path, config_path = self._ensure_voice_is_downloaded(voice)
except AttributeError as e: except AttributeError as e:
raise Exception(str(e)) from e raise Exception(str(e)) from e
if not model_path:
raise Exception('Could not download voice data')
if 'metadata' not in voice.engine_data:
with open(config_path) as f:
voice.engine_data['metadata'] = json.load(f)
audio_rate = voice.engine_data['metadata']['audio']['sample_rate']
self._current_voice = voice self._current_voice = voice
self._utterances_being_spoken.clear() self._utterances_being_spoken.clear()
self._utterances_being_synthesized.clear() self._utterances_being_synthesized.clear()
@ -284,11 +396,7 @@ class Piper(TTSBackend):
self._pending_stderr_data = b'' self._pending_stderr_data = b''
self._set_state(QTextToSpeech.State.Ready) self._set_state(QTextToSpeech.State.Ready)
cmdline = list(piper_cmdline()) + [ audio_rate, cmdline = piper_process_metadata(model_path, config_path, s, voice)
'--model', model_path, '--config', config_path, '--output-raw', '--json-input',
'--sentence-silence', str(s.sentence_delay), '--length_scale', str(length_scale)]
if is_debugging():
cmdline.append('--debug')
self._process.setProgram(cmdline[0]) self._process.setProgram(cmdline[0])
self._process.setArguments(cmdline[1:]) self._process.setArguments(cmdline[1:])
debug('Running piper:', cmdline) debug('Running piper:', cmdline)
@ -296,10 +404,7 @@ class Piper(TTSBackend):
self._process.readyReadStandardOutput.connect(self.piper_stdout_available) self._process.readyReadStandardOutput.connect(self.piper_stdout_available)
self._process.bytesWritten.connect(self.bytes_written) self._process.bytesWritten.connect(self.bytes_written)
self._process.stateChanged.connect(self._update_status) self._process.stateChanged.connect(self._update_status)
fmt = QAudioFormat() fmt = audio_format(audio_rate)
fmt.setSampleFormat(QAudioFormat.SampleFormat.Int16)
fmt.setSampleRate(audio_rate)
fmt.setChannelConfig(QAudioFormat.ChannelConfig.ChannelConfigMono)
dev = None dev = None
if s.audio_device_id: if s.audio_device_id:
for q in QMediaDevices.audioOutputs(): for q in QMediaDevices.audioOutputs():
@ -332,20 +437,14 @@ class Piper(TTSBackend):
def piper_stderr_available(self) -> None: def piper_stderr_available(self) -> None:
if self._process is not None: if self._process is not None:
def callback(ok, payload):
if ok:
if self._utterances_being_synthesized:
self._synthesis_done.emit()
else:
self._errors_from_piper.append(payload.decode('utf-8', 'replace'))
data = self._pending_stderr_data + bytes(self._process.readAllStandardError()) data = self._pending_stderr_data + bytes(self._process.readAllStandardError())
lines = data.split(b'\n') self._pending_stderr_data = detect_end_of_data(data, callback)
for line in lines[:-1]:
if m := self._stderr_pat.search(line):
which, payload = m.group(1), m.group(2)
if which == b'info':
debug(f'[piper-info] {payload.decode("utf-8", "replace")}')
if payload.startswith(b'Real-time factor:') and self._utterances_being_synthesized:
self._synthesis_done.emit()
elif which == b'error':
self._errors_from_piper.append(payload.decode('utf-8', 'replace'))
elif which == b'debug':
debug(f'[piper-debug] {payload.decode("utf-8", "replace")}')
self._pending_stderr_data = lines[-1]
def _utterance_synthesized(self): def _utterance_synthesized(self):
self.piper_stdout_available() # just in case self.piper_stdout_available() # just in case
@ -402,42 +501,7 @@ class Piper(TTSBackend):
def _load_voice_metadata(self) -> None: def _load_voice_metadata(self) -> None:
if self._voices is not None: if self._voices is not None:
return return
d = json.loads(P('piper-voices.json', data=True)) self._voice_name_map, self._voices, self._voice_for_lang = load_voice_metadata()
ans = []
lang_voices_map = {}
self._voice_name_map = {}
downloaded = set()
with suppress(OSError):
downloaded = set(os.listdir(self.cache_dir))
for bcp_code, voice_map in d['lang_map'].items():
lang, sep, country = bcp_code.partition('_')
lang = canonicalize_lang(lang) or lang
voices_for_lang = lang_voices_map.setdefault(lang, [])
for voice_name, qual_map in voice_map.items():
best_qual = voice = None
for qual, e in qual_map.items():
q = Quality.from_piper_quality(qual)
if best_qual is None or q.value < best_qual.value:
best_qual = q
mf = f'{bcp_code}-{voice_name}-{qual}.onnx'
voice = Voice(bcp_code + ':' + voice_name, lang, country, human_name=voice_name, quality=q, engine_data={
'model_url': e['model'], 'config_url': e['config'],
'model_filename': mf, 'is_downloaded': mf in downloaded,
})
if voice:
ans.append(voice)
self._voice_name_map[voice.name] = voice
voices_for_lang.append(voice)
self._voices = tuple(ans)
self._voice_for_lang = {}
for lang, voices in lang_voices_map.items():
voices.sort(key=lambda v: v.quality.value)
self._voice_for_lang[lang] = voices[0]
if lang == 'eng':
for v in voices:
if v.human_name == 'libritts':
self._voice_for_lang[lang] = v
break
@property @property
def _default_voice(self) -> Voice: def _default_voice(self) -> Voice:
@ -448,18 +512,12 @@ class Piper(TTSBackend):
@property @property
def cache_dir(self) -> str: def cache_dir(self) -> str:
return os.path.join(cache_dir(), 'piper-voices') return piper_cache_dir()
def _paths_for_voice(self, voice: Voice) -> tuple[str, str]:
fname = voice.engine_data['model_filename']
model_path = os.path.join(self.cache_dir, fname)
config_path = os.path.join(os.path.dirname(model_path), fname + '.json')
return model_path, config_path
def is_voice_downloaded(self, v: Voice) -> bool: def is_voice_downloaded(self, v: Voice) -> bool:
if not v or not v.name: if not v or not v.name:
v = self._default_voice v = self._default_voice
for path in self._paths_for_voice(v): for path in paths_for_voice(v):
if not os.path.exists(path): if not os.path.exists(path):
return False return False
return True return True
@ -467,25 +525,13 @@ class Piper(TTSBackend):
def delete_voice(self, v: Voice) -> None: def delete_voice(self, v: Voice) -> None:
if not v.name: if not v.name:
v = self._default_voice v = self._default_voice
for path in self._paths_for_voice(v): for path in paths_for_voice(v):
with suppress(FileNotFoundError): with suppress(FileNotFoundError):
os.remove(path) os.remove(path)
v.engine_data['is_downloaded'] = False v.engine_data['is_downloaded'] = False
def _download_voice(self, voice: Voice, download_even_if_exists: bool = False) -> tuple[str, str]: def _download_voice(self, voice: Voice, download_even_if_exists: bool = False) -> tuple[str, str]:
model_path, config_path = self._paths_for_voice(voice) return download_voice(voice, download_even_if_exists, parent=self, headless=False)
if os.path.exists(model_path) and os.path.exists(config_path):
if not download_even_if_exists:
return model_path, config_path
os.makedirs(os.path.dirname(model_path), exist_ok=True)
from calibre.gui2.tts.download import download_resources
ok = download_resources(_('Downloading voice for Read aloud'), _('Downloading neural network for the {} voice').format(voice.human_name), {
voice.engine_data['model_url']: (model_path, _('Neural network data')),
voice.engine_data['config_url']: (config_path, _('Neural network metadata')),
}, parent=widget_parent(self), headless=getattr(QApplication.instance(), 'headless', False)
)
voice.engine_data['is_downloaded'] = bool(ok)
return (model_path, config_path) if ok else ('', '')
def download_voice(self, v: Voice) -> None: def download_voice(self, v: Voice) -> None:
if not v.name: if not v.name:
@ -511,6 +557,122 @@ class Piper(TTSBackend):
return True return True
class PiperEmbedded:
    '''
    Run piper as a plain subprocess to convert text into raw audio data.

    Two daemon threads read the stdout and stderr pipes of the process and
    place ``(is_stdout, exception, data)`` tuples onto a queue, which
    text_to_raw_audio_data() consumes.
    '''

    # Name under which the engine specific settings are looked up in __init__()
    engine_name: str = 'piper'

    def __init__(self):
        # Process state is set first so that shutdown(), which is also used as
        # __del__, works even if something below raises
        self._current_voice = self._process = self._process_shutdown_event = None
        self._stdout_reader = self._stderr_reader = None
        self._from_process_queue = Queue()
        self._current_audio_format = QAudioFormat()
        self._embedded_settings = EngineSpecificSettings.create_from_config(self.engine_name, TTS_EMBEDED_CONFIG)
        self._voice_name_map, self._voices, self._voice_for_lang = load_voice_metadata()
        lang = get_lang()
        lang = canonicalize_lang(lang) or lang
        self._default_voice = self._voice_for_lang.get(lang) or self._voice_for_lang['eng']

    def resolve_voice(self, lang: str, voice_name: str) -> Voice:
        '''
        Pick the voice to use for ``lang``.

        Preference order: an explicitly named known voice, the preferred voice
        for the language from the embedded settings, the default voice for the
        language and finally the overall default voice.
        '''
        raw_lang = lang or get_lang() or 'en'
        lang = canonicalize_lang(raw_lang) or raw_lang
        if voice_name and voice_name in self._voice_name_map:
            return self._voice_name_map[voice_name]
        preferred = self._embedded_settings.preferred_voices.get(lang, '')
        if preferred and preferred in self._voice_name_map:
            return self._voice_name_map[preferred]
        return self._voice_for_lang.get(lang) or self._default_voice

    def text_to_raw_audio_data(
        self, texts: Iterable[str], lang: str = '', voice_name: str = '', format: QAudioFormat | None = None, timeout: float = 10.,
    ) -> Iterator[bytes]:
        '''
        Synthesize every text in ``texts``, yielding one bytes object per text.

        Empty texts yield ``b''`` without being sent to piper. The piper process
        is (re)started as needed when the resolved voice changes.

        :param timeout: Maximum number of seconds to wait without receiving any
            data from piper before giving up.
        :raises TimeoutError: if piper produces no output for ``timeout`` seconds.
        :raises Exception: if the piper process dies, or a reader thread reports an error.
        '''
        if format is None:
            format = audio_format()
        voice = self.resolve_voice(lang, voice_name)
        if voice is not self._current_voice:
            self._current_voice = voice
            self.shutdown()
        self.ensure_process_started()
        piper_done, errors_from_piper = [], []

        def callback(ok, payload):
            if ok:
                piper_done.append(True)
            else:
                # Accept both bytes and already decoded text
                if isinstance(payload, bytes):
                    payload = payload.decode('utf-8', 'replace')
                errors_from_piper.append(payload)

        for text in texts:
            if not text:
                yield b''
                continue
            # piper is started with --json-input, so it expects one JSON object per line
            self._process.stdin.write(json.dumps({'text': text}).encode('utf-8') + b'\n')
            self._process.stdin.flush()
            stderr_data = b''
            buf = []
            piper_done.clear()
            errors_from_piper.clear()
            last_output_at = monotonic()
            while not piper_done:
                try:
                    is_stdout, exception, data = self._from_process_queue.get(True, 1.0)
                except Empty:
                    if self._process.poll() is not None:
                        raise Exception(f'piper process died with error code: {self._process.poll()} and stderr: {"".join(errors_from_piper)}')
                    if monotonic() - last_output_at > timeout:
                        raise TimeoutError(f'piper process produced no output for {timeout} seconds. stderr: {"".join(errors_from_piper)}')
                    continue
                if exception is not None:
                    raise exception
                # Any data from piper counts as output for the purposes of the timeout
                last_output_at = monotonic()
                if is_stdout:
                    buf.append(data)
                else:
                    stderr_data = detect_end_of_data(stderr_data + data, callback)
            # NOTE(review): stdout data that arrives after the end marker on stderr
            # is not collected here -- confirm piper flushes audio before logging.
            # TODO: convert the audio when format != self._current_audio_format,
            # currently the data is yielded in the format produced by the voice.
            yield b''.join(buf)

    def ensure_voices_downloaded(self, specs: Iterable[tuple[str, str]], parent: QObject | None = None) -> None:
        '''Download the voice for every ``(lang, voice_name)`` pair in ``specs``, headless when there is no ``parent``.'''
        for lang, voice_name in specs:
            voice = self.resolve_voice(lang, voice_name)
            download_voice(voice, parent=parent, headless=parent is None)

    def shutdown(self):
        '''Kill the piper process, if any, and wait for the reader threads to exit.'''
        process = getattr(self, '_process', None)
        if process is None:
            return
        self._process = None
        shutdown_event, self._process_shutdown_event = self._process_shutdown_event, None
        if shutdown_event is not None:
            shutdown_event.set()
        # Closing stdin flushes it, which fails if the process is already gone
        with suppress(OSError):
            process.stdin.close()
        process.kill()
        process.wait()
        # Once the process is dead the readers see EOF and exit, so join them
        # before closing the pipes they are reading from
        for reader_thread in (self._stdout_reader, self._stderr_reader):
            if reader_thread is not None:
                reader_thread.join()
        self._stdout_reader = self._stderr_reader = None
        process.stdout.close()
        process.stderr.close()

    __del__ = shutdown

    def ensure_process_started(self):
        '''Start piper for the current voice, downloading the voice if needed. Does nothing if already running.'''
        if self._process is not None:
            return
        model_path, config_path = download_voice(self._current_voice, headless=True)
        audio_rate, cmdline = piper_process_metadata(model_path, config_path, self._embedded_settings, self._current_voice)
        self._current_audio_format = audio_format(audio_rate)
        import subprocess
        from threading import Thread
        shutdown_event = Event()
        # The queue and the process must exist before the reader threads are
        # created, since the threads are handed the pipes of the process
        self._from_process_queue = Queue()
        self._process = subprocess.Popen(cmdline, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
        self._process_shutdown_event = shutdown_event
        self._stdout_reader = Thread(target=self.reader, args=(shutdown_event, self._process.stdout, True), daemon=True)
        self._stderr_reader = Thread(target=self.reader, args=(shutdown_event, self._process.stderr, False), daemon=True)
        self._stdout_reader.start()
        self._stderr_reader.start()

    def reader(self, shutdown_event: Event, pipe: BinaryIO, is_stdout: bool = False) -> None:
        '''
        Thread body: forward data from ``pipe`` to the queue until EOF, error or shutdown.

        Each item put on the queue is ``(is_stdout, exception, data)``.
        '''
        # read1() returns as soon as some data is available, whereas read()
        # without a size would block until the process closes the pipe
        read = getattr(pipe, 'read1', None) or pipe.read
        while not shutdown_event.is_set():
            try:
                data = read(64 * 1024)
            except Exception as e:
                if not shutdown_event.is_set():
                    self._from_process_queue.put((is_stdout, e, b''))
                break
            if not data:
                # EOF: the process closed its end of the pipe
                break
            self._from_process_queue.put((is_stdout, None, data))
def develop(): # {{{ def develop(): # {{{
from qt.core import QSocketNotifier from qt.core import QSocketNotifier