From a94fad543b6ad7960dba999f01cdc48d6d91a20f Mon Sep 17 00:00:00 2001 From: yoni13 Date: Thu, 9 Jan 2025 10:38:40 +0000 Subject: [PATCH] all infrencing works with 1 max job concurrency --- machine-learning/app/sessions/rknn.py | 72 ++++++++++++++------------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/machine-learning/app/sessions/rknn.py b/machine-learning/app/sessions/rknn.py index 4cd1de3327..356a78c30b 100644 --- a/machine-learning/app/sessions/rknn.py +++ b/machine-learning/app/sessions/rknn.py @@ -4,8 +4,10 @@ from pathlib import Path from typing import Any, List import numpy as np +import onnxruntime as ort from numpy.typing import NDArray -from rknnlite.api import RKNNLite # Importing RKNN API +from rknnlite.api import RKNNLite + from app.models.constants import SUPPORTED_PROVIDERS from app.schemas import SessionNode @@ -16,55 +18,55 @@ from ..config import log, settings class RknnSession: def __init__(self, model_path: Path | str): self.model_path = Path(model_path) - self.rknn = RKNNLite() # Initialize RKNN object -# self.rknn.config(target_platform='rk3566') - # Load the RKNN model - log.info(f"Loading RKNN model from {self.model_path}") - self._load_model() - def _load_model(self) -> None: -# ret = self.rknn.load_onnx(self.model_path.as_posix()) + self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx") + + self.rknn = RKNNLite() + + log.info(f"Loading RKNN model from {self.model_path}") + + self.ort_session = ort.InferenceSession( + self.ort_model_path, + ) + + # ret = self.rknn.load_onnx(self.model_path) print('--> Load RKNN model') ret = self.rknn.load_rknn(self.model_path.as_posix()) if ret != 0: raise RuntimeError("Failed to load RKNN model") - # print('--> Building model') - # ret = self.rknn.build(do_quantization=False) - if ret != 0: - print('Build model failed!') - exit(ret) ret = self.rknn.init_runtime() if ret != 0: - raise RuntimeError("Failed to initialize RKNN runtime") + raise RuntimeError("Failed to initialize RKNN runtime") - def get_inputs(self) -> List[SessionNode]: - input_attrs = self.rknn.query_inputs() - return input_attrs # RKNN does not provide direct SessionNode equivalent + def get_inputs(self) -> list[SessionNode]: + inputs: list[SessionNode] = self.ort_session.get_inputs() + return inputs - def get_outputs(self) -> List[SessionNode]: - output_attrs = self.rknn.query_outputs() - return output_attrs + def get_outputs(self) -> list[SessionNode]: + outputs: list[SessionNode] = self.ort_session.get_outputs() + return outputs def run( self, output_names: list[str] | None, input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]], - run_options: Any = None, - ) -> List[NDArray[np.float32]]: - inputs = [v for v in input_feed.values()] + run_options: Any = None, + ): - # Run inference - log.debug(f"Running inference on RKNN model") - outputs = self.rknn.inference(inputs=inputs) + input_data = [np.ascontiguousarray(v) for v in input_feed.values()][0] + + + # log.info(f"Running inference on RKNN model") + + outputs = self.rknn.inference(inputs=[input_data]) + if not outputs: + outputs = self.rknn.inference(inputs=[input_data], data_format='nchw') + + # log.info("inputs:") + # log.info(input_data) + # log.info("outputs:") + # log.info(outputs) + # log.info("RKNN END") return outputs - def release(self) -> None: - log.info("Releasing RKNN resources") - self.rknn.release() - - -# Example Usage: -# session = RknnSession(model_path="path/to/model.rknn") -# outputs = session.run(input_feed={"input_name": input_data}) -# session.release()