More work on using GPU to run inferencing

This commit is contained in:
Kovid Goyal 2025-07-30 14:27:00 +05:30
parent c7907e2081
commit fa582d8f26
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 42 additions and 4 deletions

View File

@ -137,7 +137,10 @@ def freeze(env, ext_dir, incdir):
# piper # piper
for x in ('espeak-ng-data',): for x in ('espeak-ng-data',):
shutil.copytree(os.path.join(PREFIX, 'share', x), os.path.join(env.share_dir, x)) shutil.copytree(os.path.join(PREFIX, 'share', x), os.path.join(env.share_dir, x))
copybin(os.path.join(libdir, "onnxruntime.dll")) for f in glob.glob(os.path.join(libdir, 'onnxruntime*.dll')):
copybin(f)
for f in glob.glob(os.path.join(libdir, 'DirectML*.dll')):
copybin(f)
for f in glob.glob(os.path.join(bindir, '*.dll')): for f in glob.glob(os.path.join(bindir, '*.dll')):
if re.search(r'(easylzma|icutest)', f.lower()) is None: if re.search(r'(easylzma|icutest)', f.lower()) is None:

View File

@ -15,6 +15,7 @@
#include <cstdint> #include <cstdint>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <chrono>
#ifdef _WIN32 #ifdef _WIN32
#define ORT_DLL_IMPORT #define ORT_DLL_IMPORT
#endif #endif
@ -35,6 +36,8 @@
(45 | CLAUSE_INTONATION_EXCLAMATION | CLAUSE_TYPE_SENTENCE) (45 | CLAUSE_INTONATION_EXCLAMATION | CLAUSE_TYPE_SENTENCE)
#define CLAUSE_COLON (30 | CLAUSE_INTONATION_FULL_STOP | CLAUSE_TYPE_CLAUSE) #define CLAUSE_COLON (30 | CLAUSE_INTONATION_FULL_STOP | CLAUSE_TYPE_CLAUSE)
#define CLAUSE_SEMICOLON (30 | CLAUSE_INTONATION_COMMA | CLAUSE_TYPE_CLAUSE) #define CLAUSE_SEMICOLON (30 | CLAUSE_INTONATION_COMMA | CLAUSE_TYPE_CLAUSE)
// Compile-time switch: when false, set_available_providers() returns before
// probing any execution provider, so available_providers is left untouched.
static const bool USE_GPU = false;
// Compile-time switch: when true, the model loading time and model run time
// are measured with now() and printed to stdout in seconds.
static const bool PRINT_TIMING_INFORMATION = false;
typedef char32_t Phoneme; typedef char32_t Phoneme;
typedef int64_t PhonemeId; typedef int64_t PhonemeId;
@ -99,8 +102,25 @@ static std::vector<std::string> available_providers;
// Fill available_providers with the execution providers that can actually be
// attached to a session, ordered by PRIORITY_ORDER.
// NOTE(review): several lines below carry the pre-change and post-change text
// of the diff side by side; the comments describe the post-change (right-hand)
// version.
static void static void
set_available_providers() { set_available_providers() {
// Probe at most once, and not at all when USE_GPU is false.
if (!available_providers.empty()) return; static bool providers_set = false;
available_providers = Ort::GetAvailableProviders(); if (providers_set || !USE_GPU) return;
providers_set = true;
// Scratch session options, used only to test whether each provider can be appended.
Ort::SessionOptions opts;
opts.DisableCpuMemArena();
opts.DisableMemPattern();
opts.DisableProfiling();
Ort::Env ort_env{ORT_LOGGING_LEVEL_WARNING, "piper"};
ort_env.DisableTelemetryEvents();
for (const std::string& s : Ort::GetAvailableProviders()) {
// The CPU provider is never recorded in available_providers.
if (s == "CPUExecutionProvider") continue;
std::unordered_map<std::string, std::string> provider_options;
try {
// A provider is recorded only if appending it (with empty options) does not throw.
opts.AppendExecutionProvider(s, provider_options);
available_providers.push_back(s);
} catch (const Ort::Exception& e) {
// Deliberately ignored: a provider that cannot be appended is simply left out.
}
}
sort_providers_by_priority(available_providers, PRIORITY_ORDER); sort_providers_by_priority(available_providers, PRIORITY_ORDER);
} }
@ -182,6 +202,11 @@ phonemize(PyObject *self, PyObject *pytext) {
return phonemes_and_terminators; return phonemes_and_terminators;
} }
// Current reading of the monotonic steady clock, expressed as a count of
// nanoseconds since that clock's epoch. Only differences between two
// readings are meaningful.
static long long
now() {
    const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
    const std::chrono::nanoseconds as_ns =
        std::chrono::duration_cast<std::chrono::nanoseconds>(since_epoch);
    return as_ns.count();
}
static PyObject* static PyObject*
set_voice(PyObject *self, PyObject *args) { set_voice(PyObject *self, PyObject *args) {
PyObject *cfg; PyObject *pymp; PyObject *cfg; PyObject *pymp;
@ -228,13 +253,19 @@ set_voice(PyObject *self, PyObject *args) {
// Load onnx model // Load onnx model
Py_BEGIN_ALLOW_THREADS; Py_BEGIN_ALLOW_THREADS;
static Ort::SessionOptions opts; Ort::SessionOptions opts;
opts.DisableCpuMemArena(); opts.DisableCpuMemArena();
opts.DisableMemPattern(); opts.DisableMemPattern();
opts.DisableProfiling(); opts.DisableProfiling();
Ort::Env ort_env{ORT_LOGGING_LEVEL_WARNING, "piper"}; Ort::Env ort_env{ORT_LOGGING_LEVEL_WARNING, "piper"};
ort_env.DisableTelemetryEvents(); ort_env.DisableTelemetryEvents();
for (const auto& p : available_providers) {
std::unordered_map<std::string, std::string> provider_options;
opts.AppendExecutionProvider(p, provider_options);
}
session.reset(); session.reset();
long long st;
if (PRINT_TIMING_INFORMATION) st = now();
#ifdef _WIN32 #ifdef _WIN32
wchar_t *model_path = PyUnicode_AsWideCharString(pymp, NULL); wchar_t *model_path = PyUnicode_AsWideCharString(pymp, NULL);
if (!model_path) return NULL; if (!model_path) return NULL;
@ -243,6 +274,7 @@ set_voice(PyObject *self, PyObject *args) {
#else #else
session = std::make_unique<Ort::Session>(Ort::Session(ort_env, PyUnicode_AsUTF8(pymp), opts)); session = std::make_unique<Ort::Session>(Ort::Session(ort_env, PyUnicode_AsUTF8(pymp), opts));
#endif #endif
if (PRINT_TIMING_INFORMATION) { printf("model loading time: %f\n", (now()-st) / 1e9); fflush(stdout); }
Py_END_ALLOW_THREADS; Py_END_ALLOW_THREADS;
@ -385,9 +417,12 @@ next(PyObject *self, PyObject *args) {
// Infer // Infer
Ort::RunOptions ro; Ort::RunOptions ro;
long long st;
if (PRINT_TIMING_INFORMATION) st = now();
output_tensors = session->Run( output_tensors = session->Run(
ro, input_names.data(), input_tensors.data(), ro, input_names.data(), input_tensors.data(),
input_tensors.size(), output_names.data(), output_names.size()); input_tensors.size(), output_names.data(), output_names.size());
if (PRINT_TIMING_INFORMATION) { printf("model run time: %f\n", (now()-st) / 1e9); fflush(stdout); }
Py_END_ALLOW_THREADS; Py_END_ALLOW_THREADS;
if ((output_tensors.size() != 1) || (!output_tensors.front().IsTensor())) { if ((output_tensors.size() != 1) || (!output_tensors.front().IsTensor())) {