diff --git a/bypy/windows/__main__.py b/bypy/windows/__main__.py index 8ae57e86ec..76785f5bc6 100644 --- a/bypy/windows/__main__.py +++ b/bypy/windows/__main__.py @@ -137,7 +137,10 @@ def freeze(env, ext_dir, incdir): # piper for x in ('espeak-ng-data',): shutil.copytree(os.path.join(PREFIX, 'share', x), os.path.join(env.share_dir, x)) - copybin(os.path.join(libdir, "onnxruntime.dll")) + for f in glob.glob(os.path.join(libdir, 'onnxruntime*.dll')): + copybin(f) + for f in glob.glob(os.path.join(libdir, 'DirectML*.dll')): + copybin(f) for f in glob.glob(os.path.join(bindir, '*.dll')): if re.search(r'(easylzma|icutest)', f.lower()) is None: diff --git a/src/calibre/utils/tts/piper.cpp b/src/calibre/utils/tts/piper.cpp index 357000a3bd..02d73dc77b 100644 --- a/src/calibre/utils/tts/piper.cpp +++ b/src/calibre/utils/tts/piper.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #ifdef _WIN32 #define ORT_DLL_IMPORT #endif @@ -35,6 +36,8 @@ (45 | CLAUSE_INTONATION_EXCLAMATION | CLAUSE_TYPE_SENTENCE) #define CLAUSE_COLON (30 | CLAUSE_INTONATION_FULL_STOP | CLAUSE_TYPE_CLAUSE) #define CLAUSE_SEMICOLON (30 | CLAUSE_INTONATION_COMMA | CLAUSE_TYPE_CLAUSE) +static const bool USE_GPU = false; +static const bool PRINT_TIMING_INFORMATION = false; typedef char32_t Phoneme; typedef int64_t PhonemeId; @@ -99,8 +102,25 @@ static std::vector available_providers; static void set_available_providers() { - if (!available_providers.empty()) return; - available_providers = Ort::GetAvailableProviders(); + static bool providers_set = false; + if (providers_set || !USE_GPU) return; + providers_set = true; + Ort::SessionOptions opts; + opts.DisableCpuMemArena(); + opts.DisableMemPattern(); + opts.DisableProfiling(); + Ort::Env ort_env{ORT_LOGGING_LEVEL_WARNING, "piper"}; + ort_env.DisableTelemetryEvents(); + + for (const std::string& s : Ort::GetAvailableProviders()) { + if (s == "CPUExecutionProvider") continue; + std::unordered_map provider_options; + try { + opts.AppendExecutionProvider(s, provider_options); + available_providers.push_back(s); + } catch (const Ort::Exception& e) { + } + } sort_providers_by_priority(available_providers, PRIORITY_ORDER); } @@ -182,6 +202,11 @@ phonemize(PyObject *self, PyObject *pytext) { return phonemes_and_terminators; } +static +long long now() { + return std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count(); +} + static PyObject* set_voice(PyObject *self, PyObject *args) { PyObject *cfg; PyObject *pymp; @@ -228,13 +253,19 @@ set_voice(PyObject *self, PyObject *args) { // Load onnx model Py_BEGIN_ALLOW_THREADS; - static Ort::SessionOptions opts; + Ort::SessionOptions opts; opts.DisableCpuMemArena(); opts.DisableMemPattern(); opts.DisableProfiling(); Ort::Env ort_env{ORT_LOGGING_LEVEL_WARNING, "piper"}; ort_env.DisableTelemetryEvents(); + for (const auto& p : available_providers) { + std::unordered_map provider_options; + opts.AppendExecutionProvider(p, provider_options); + } session.reset(); + long long st; + if (PRINT_TIMING_INFORMATION) st = now(); #ifdef _WIN32 wchar_t *model_path = PyUnicode_AsWideCharString(pymp, NULL); if (!model_path) return NULL; @@ -243,6 +274,7 @@ set_voice(PyObject *self, PyObject *args) { #else session = std::make_unique(Ort::Session(ort_env, PyUnicode_AsUTF8(pymp), opts)); #endif + if (PRINT_TIMING_INFORMATION) { printf("model loading time: %f\n", (now()-st) / 1e9); fflush(stdout); } Py_END_ALLOW_THREADS; @@ -385,9 +417,12 @@ next(PyObject *self, PyObject *args) { // Infer Ort::RunOptions ro; + long long st; + if (PRINT_TIMING_INFORMATION) st = now(); output_tensors = session->Run( ro, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), output_names.size()); + if (PRINT_TIMING_INFORMATION) { printf("model run time: %f\n", (now()-st) / 1e9); fflush(stdout); } Py_END_ALLOW_THREADS; if ((output_tensors.size() != 1) || (!output_tensors.front().IsTensor())) {