diff --git a/machine-learning/app/sessions/rknn.py b/machine-learning/app/sessions/rknn.py index 9607a65c41..4121839afb 100644 --- a/machine-learning/app/sessions/rknn.py +++ b/machine-learning/app/sessions/rknn.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any import numpy as np -import onnxruntime as ort from numpy.typing import NDArray from app.schemas import SessionNode @@ -22,7 +21,6 @@ def runInference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArr class RknnSession: def __init__(self, model_path: Path | str): self.model_path = Path(str(model_path).replace("model", soc_name)) - self.ort_model_path = Path(str(self.model_path).replace(f"{soc_name}.rknn", "model.onnx")) self.tpe = settings.rknn_threads @@ -33,28 +31,12 @@ class RknnSession: def __del__(self) -> None: self.rknnpool.release() - def _load_ort_session(self) -> None: - self.ort_session = ort.InferenceSession( - self.ort_model_path.as_posix(), - ) - self.inputs: list[SessionNode] = self.ort_session.get_inputs() - self.outputs: list[SessionNode] = self.ort_session.get_outputs() - del self.ort_session - def get_inputs(self) -> list[SessionNode]: - try: - return self.inputs - except AttributeError: - self._load_ort_session() - return self.inputs + raise NotImplementedError def get_outputs(self) -> list[SessionNode]: - try: - return self.outputs - except AttributeError: - self._load_ort_session() - return self.outputs - + raise NotImplementedError + def run( self, output_names: list[str] | None,