Load model by SOC name

This commit is contained in:
yoni13 2025-01-13 17:08:16 +08:00
parent daf886088a
commit ebdfe1b7b6
2 changed files with 7 additions and 5 deletions

View File

@ -8,7 +8,7 @@ import onnxruntime as ort
from numpy.typing import NDArray
from app.schemas import SessionNode
from rknn.rknnpool import rknnPoolExecutor
from rknn.rknnpool import rknnPoolExecutor, soc_name
from ..config import log, settings
@ -21,8 +21,8 @@ def runInfrence(rknn_lite, input):
class RknnSession:
def __init__(self, model_path: Path | str):
self.model_path = Path(model_path)
self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx")
self.model_path = Path(str(model_path).replace("model", soc_name))
self.ort_model_path = Path(str(self.model_path).replace(f"{soc_name}.rknn", "model.onnx"))
self.inputs = None
self.outputs = None
@ -42,7 +42,7 @@ class RknnSession:
def get_inputs(self) -> list[SessionNode]:
if not self.inputs:
self.ort_session = ort.InferenceSession(
self.ort_model_path,
self.ort_model_path.as_posix(),
)
self.inputs = self.ort_session.get_inputs()
self.outputs = self.ort_session.get_outputs()
@ -52,7 +52,7 @@ class RknnSession:
def get_outputs(self) -> list[SessionNode]:
if not self.outputs:
self.ort_session = ort.InferenceSession(
self.ort_model_path,
self.ort_model_path.as_posix(),
)
self.inputs = self.ort_session.get_inputs()
self.outputs = self.ort_session.get_outputs()

View File

@ -20,9 +20,11 @@ try:
break
else:
is_available = False
soc_name = None
is_available = os.path.exists("/sys/kernel/debug/rknpu/load")
except (FileNotFoundError, ImportError):
is_available = False
soc_name = None
def initRKNN(rknnModel="./rknnModel/yolov5s.rknn", id=0):