All inference works with a max job concurrency of 1

This commit is contained in:
yoni13 2025-01-09 10:38:40 +00:00
parent 082c426e34
commit a94fad543b

View File

@@ -4,8 +4,10 @@ from pathlib import Path
from typing import Any, List from typing import Any, List
import numpy as np import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray from numpy.typing import NDArray
from rknnlite.api import RKNNLite # Importing RKNN API from rknnlite.api import RKNNLite
from app.models.constants import SUPPORTED_PROVIDERS from app.models.constants import SUPPORTED_PROVIDERS
from app.schemas import SessionNode from app.schemas import SessionNode
@@ -16,55 +18,55 @@ from ..config import log, settings
class RknnSession: class RknnSession:
def __init__(self, model_path: Path | str): def __init__(self, model_path: Path | str):
self.model_path = Path(model_path) self.model_path = Path(model_path)
self.rknn = RKNNLite() # Initialize RKNN object
# self.rknn.config(target_platform='rk3566')
# Load the RKNN model
log.info(f"Loading RKNN model from {self.model_path}")
self._load_model()
def _load_model(self) -> None: self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx")
# ret = self.rknn.load_onnx(self.model_path.as_posix())
self.rknn = RKNNLite()
log.info(f"Loading RKNN model from {self.model_path}")
self.ort_session = ort.InferenceSession(
self.ort_model_path,
)
# ret = self.rknn.load_onnx(self.model_path)
print('--> Load RKNN model') print('--> Load RKNN model')
ret = self.rknn.load_rknn(self.model_path.as_posix()) ret = self.rknn.load_rknn(self.model_path.as_posix())
if ret != 0: if ret != 0:
raise RuntimeError("Failed to load RKNN model") raise RuntimeError("Failed to load RKNN model")
# print('--> Building model')
# ret = self.rknn.build(do_quantization=False)
if ret != 0:
print('Build model failed!')
exit(ret)
ret = self.rknn.init_runtime() ret = self.rknn.init_runtime()
if ret != 0: if ret != 0:
raise RuntimeError("Failed to initialize RKNN runtime") raise RuntimeError("Failed to initialize RKNN runtime")
def get_inputs(self) -> List[SessionNode]: def get_inputs(self) -> list[SessionNode]:
input_attrs = self.rknn.query_inputs() inputs: list[SessionNode] = self.ort_session.get_inputs()
return input_attrs # RKNN does not provide direct SessionNode equivalent return inputs
def get_outputs(self) -> List[SessionNode]: def get_outputs(self) -> list[SessionNode]:
output_attrs = self.rknn.query_outputs() outputs: list[SessionNode] = self.ort_session.get_outputs()
return output_attrs return outputs
def run( def run(
self, self,
output_names: list[str] | None, output_names: list[str] | None,
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]], input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
run_options: Any = None, run_options: Any = None,
) -> List[NDArray[np.float32]]: ):
inputs = [v for v in input_feed.values()]
# Run inference input_data = [np.ascontiguousarray(v) for v in input_feed.values()][0]
log.debug(f"Running inference on RKNN model")
outputs = self.rknn.inference(inputs=inputs)
# log.info(f"Running inference on RKNN model")
outputs = self.rknn.inference(inputs=[input_data])
if not outputs:
outputs = self.rknn.inference(inputs=[input_data], data_format='nchw')
# log.info("inputs:")
# log.info(input_data)
# log.info("outputs:")
# log.info(outputs)
# log.info("RKNN END")
return outputs return outputs
def release(self) -> None:
log.info("Releasing RKNN resources")
self.rknn.release()
# Example Usage:
# session = RknnSession(model_path="path/to/model.rknn")
# outputs = session.run(input_feed={"input_name": input_data})
# session.release()