From efaf70eb9d6ade3f99238765c87375e36b1690f2 Mon Sep 17 00:00:00 2001 From: yoni13 Date: Sun, 12 Jan 2025 01:02:16 +0800 Subject: [PATCH] Set running threads from env --- machine-learning/app/config.py | 3 +++ machine-learning/app/sessions/rknn.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/machine-learning/app/config.py b/machine-learning/app/config.py index dcc60c46fc..e6d63368b2 100644 --- a/machine-learning/app/config.py +++ b/machine-learning/app/config.py @@ -45,6 +45,9 @@ class Settings(BaseSettings): ann_fp16_turbo: bool = False ann_tuning_level: int = 2 rknn: bool = True + rknn_textual_threads: int = 1 + rknn_visual_threads: int = 2 + rknn_facial_detection_threads: int = 2 preload: PreloadModelData | None = None max_batch_size: MaxBatchSize | None = None diff --git a/machine-learning/app/sessions/rknn.py b/machine-learning/app/sessions/rknn.py index ddc93ecbf7..0606cef4c2 100644 --- a/machine-learning/app/sessions/rknn.py +++ b/machine-learning/app/sessions/rknn.py @@ -10,7 +10,7 @@ from numpy.typing import NDArray from app.schemas import SessionNode from rknn.rknnpool import rknnPoolExecutor -from ..config import log +from ..config import log, settings def runInfrence(rknn_lite, input): @@ -23,7 +23,13 @@ class RknnSession: def __init__(self, model_path: Path | str): self.model_path = Path(model_path) self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx") - self.tpe = 1 if "textual" in str(self.model_path) else 2 + + if 'textual' in self.model_path.name: + self.tpe = settings.rknn_textual_threads + elif 'visual' in self.model_path.name: + self.tpe = settings.rknn_visual_threads + else: + self.tpe = settings.rknn_facial_detection_threads log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.") self.rknnpool = rknnPoolExecutor(rknnModel=self.model_path.as_posix(), TPEs=self.tpe, func=runInfrence)