From ebdfe1b7b6f5c006c4752fcf7ad7a3d61faffa58 Mon Sep 17 00:00:00 2001 From: yoni13 Date: Mon, 13 Jan 2025 17:08:16 +0800 Subject: [PATCH] Load model by SOC name --- machine-learning/app/sessions/rknn.py | 10 +++++----- machine-learning/rknn/rknnpool.py | 2 ++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/machine-learning/app/sessions/rknn.py b/machine-learning/app/sessions/rknn.py index 4e5e85a242..421a6971f9 100644 --- a/machine-learning/app/sessions/rknn.py +++ b/machine-learning/app/sessions/rknn.py @@ -8,7 +8,7 @@ import onnxruntime as ort from numpy.typing import NDArray from app.schemas import SessionNode -from rknn.rknnpool import rknnPoolExecutor +from rknn.rknnpool import rknnPoolExecutor, soc_name from ..config import log, settings @@ -21,8 +21,8 @@ def runInfrence(rknn_lite, input): class RknnSession: def __init__(self, model_path: Path | str): - self.model_path = Path(model_path) - self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx") + self.model_path = Path(str(model_path).replace("model", soc_name)) + self.ort_model_path = Path(str(self.model_path).replace(f"{soc_name}.rknn", "model.onnx")) self.inputs = None self.outputs = None @@ -42,7 +42,7 @@ class RknnSession: def get_inputs(self) -> list[SessionNode]: if not self.inputs: self.ort_session = ort.InferenceSession( - self.ort_model_path, + self.ort_model_path.as_posix(), ) self.inputs = self.ort_session.get_inputs() self.outputs = self.ort_session.get_outputs() @@ -52,7 +52,7 @@ class RknnSession: def get_outputs(self) -> list[SessionNode]: if not self.outputs: self.ort_session = ort.InferenceSession( - self.ort_model_path, + self.ort_model_path.as_posix(), ) self.inputs = self.ort_session.get_inputs() self.outputs = self.ort_session.get_outputs() diff --git a/machine-learning/rknn/rknnpool.py b/machine-learning/rknn/rknnpool.py index ea5a2403e7..64e4af9f7c 100644 --- a/machine-learning/rknn/rknnpool.py +++ b/machine-learning/rknn/rknnpool.py @@ -20,9 +20,11 @@ try: break else: is_available = False + soc_name = None is_available = os.path.exists("/sys/kernel/debug/rknpu/load") except (FileNotFoundError, ImportError): is_available = False + soc_name = None def initRKNN(rknnModel="./rknnModel/yolov5s.rknn", id=0):