mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-31 14:33:54 -04:00
More work on using GPU to run inferencing
This commit is contained in:
parent
c7907e2081
commit
fa582d8f26
@ -137,7 +137,10 @@ def freeze(env, ext_dir, incdir):
|
||||
# piper
|
||||
for x in ('espeak-ng-data',):
|
||||
shutil.copytree(os.path.join(PREFIX, 'share', x), os.path.join(env.share_dir, x))
|
||||
copybin(os.path.join(libdir, "onnxruntime.dll"))
|
||||
for f in glob.glob(os.path.join(libdir, 'onnxruntime*.dll')):
|
||||
copybin(f)
|
||||
for f in glob.glob(os.path.join(libdir, 'DirectML*.dll')):
|
||||
copybin(f)
|
||||
|
||||
for f in glob.glob(os.path.join(bindir, '*.dll')):
|
||||
if re.search(r'(easylzma|icutest)', f.lower()) is None:
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include <cstdint>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <chrono>
|
||||
#ifdef _WIN32
|
||||
#define ORT_DLL_IMPORT
|
||||
#endif
|
||||
@ -35,6 +36,8 @@
|
||||
(45 | CLAUSE_INTONATION_EXCLAMATION | CLAUSE_TYPE_SENTENCE)
|
||||
#define CLAUSE_COLON (30 | CLAUSE_INTONATION_FULL_STOP | CLAUSE_TYPE_CLAUSE)
|
||||
#define CLAUSE_SEMICOLON (30 | CLAUSE_INTONATION_COMMA | CLAUSE_TYPE_CLAUSE)
|
||||
static const bool USE_GPU = false;
|
||||
static const bool PRINT_TIMING_INFORMATION = false;
|
||||
|
||||
typedef char32_t Phoneme;
|
||||
typedef int64_t PhonemeId;
|
||||
@ -99,8 +102,25 @@ static std::vector<std::string> available_providers;
|
||||
|
||||
static void
|
||||
set_available_providers() {
|
||||
if (!available_providers.empty()) return;
|
||||
available_providers = Ort::GetAvailableProviders();
|
||||
static bool providers_set = false;
|
||||
if (providers_set || !USE_GPU) return;
|
||||
providers_set = true;
|
||||
Ort::SessionOptions opts;
|
||||
opts.DisableCpuMemArena();
|
||||
opts.DisableMemPattern();
|
||||
opts.DisableProfiling();
|
||||
Ort::Env ort_env{ORT_LOGGING_LEVEL_WARNING, "piper"};
|
||||
ort_env.DisableTelemetryEvents();
|
||||
|
||||
for (const std::string& s : Ort::GetAvailableProviders()) {
|
||||
if (s == "CPUExecutionProvider") continue;
|
||||
std::unordered_map<std::string, std::string> provider_options;
|
||||
try {
|
||||
opts.AppendExecutionProvider(s, provider_options);
|
||||
available_providers.push_back(s);
|
||||
} catch (const Ort::Exception& e) {
|
||||
}
|
||||
}
|
||||
sort_providers_by_priority(available_providers, PRIORITY_ORDER);
|
||||
}
|
||||
|
||||
@ -182,6 +202,11 @@ phonemize(PyObject *self, PyObject *pytext) {
|
||||
return phonemes_and_terminators;
|
||||
}
|
||||
|
||||
static
|
||||
long long now() {
|
||||
return std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
set_voice(PyObject *self, PyObject *args) {
|
||||
PyObject *cfg; PyObject *pymp;
|
||||
@ -228,13 +253,19 @@ set_voice(PyObject *self, PyObject *args) {
|
||||
|
||||
// Load onnx model
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
static Ort::SessionOptions opts;
|
||||
Ort::SessionOptions opts;
|
||||
opts.DisableCpuMemArena();
|
||||
opts.DisableMemPattern();
|
||||
opts.DisableProfiling();
|
||||
Ort::Env ort_env{ORT_LOGGING_LEVEL_WARNING, "piper"};
|
||||
ort_env.DisableTelemetryEvents();
|
||||
for (const auto& p : available_providers) {
|
||||
std::unordered_map<std::string, std::string> provider_options;
|
||||
opts.AppendExecutionProvider(p, provider_options);
|
||||
}
|
||||
session.reset();
|
||||
long long st;
|
||||
if (PRINT_TIMING_INFORMATION) st = now();
|
||||
#ifdef _WIN32
|
||||
wchar_t *model_path = PyUnicode_AsWideCharString(pymp, NULL);
|
||||
if (!model_path) return NULL;
|
||||
@ -243,6 +274,7 @@ set_voice(PyObject *self, PyObject *args) {
|
||||
#else
|
||||
session = std::make_unique<Ort::Session>(Ort::Session(ort_env, PyUnicode_AsUTF8(pymp), opts));
|
||||
#endif
|
||||
if (PRINT_TIMING_INFORMATION) { printf("model loading time: %f\n", (now()-st) / 1e9); fflush(stdout); }
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
|
||||
@ -385,9 +417,12 @@ next(PyObject *self, PyObject *args) {
|
||||
|
||||
// Infer
|
||||
Ort::RunOptions ro;
|
||||
long long st;
|
||||
if (PRINT_TIMING_INFORMATION) st = now();
|
||||
output_tensors = session->Run(
|
||||
ro, input_names.data(), input_tensors.data(),
|
||||
input_tensors.size(), output_names.data(), output_names.size());
|
||||
if (PRINT_TIMING_INFORMATION) { printf("model run time: %f\n", (now()-st) / 1e9); fflush(stdout); }
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
if ((output_tensors.size() != 1) || (!output_tensors.front().IsTensor())) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user