73 lines
2.0 KiB
Python

from __future__ import annotations
from pathlib import Path
from typing import Any, List
import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray
from rknnlite.api import RKNNLite
from app.models.constants import SUPPORTED_PROVIDERS
from app.schemas import SessionNode
from ..config import log, settings
class RknnSession:
    """Inference session that executes a model through an RKNN runtime.

    The method set (``get_inputs``, ``get_outputs``, ``run``) follows the
    shape of ``onnxruntime.InferenceSession``.

    Two files are loaded: the RKNN model itself, which performs inference,
    and a companion ONNX file (same path, ``.onnx`` suffix), which is only
    queried for input/output node metadata.
    """

    def __init__(self, model_path: Path | str):
        """Load the RKNN model and its companion ONNX file.

        Args:
            model_path: Location of the ``.rknn`` model. The companion ONNX
                file is looked up at the same path with its final ``.rknn``
                suffix swapped for ``.onnx``.

        Raises:
            RuntimeError: If the RKNN model cannot be loaded or the RKNN
                runtime cannot be initialized (non-zero return code).
        """
        self.model_path = Path(model_path)
        self.ort_model_path = self._companion_onnx_path(self.model_path)
        self.rknn = RKNNLite()

        log.info(f"Loading RKNN model from {self.model_path}")
        self.ort_session = ort.InferenceSession(self.ort_model_path)

        log.info("--> Load RKNN model")
        if self.rknn.load_rknn(self.model_path.as_posix()) != 0:
            raise RuntimeError("Failed to load RKNN model")
        if self.rknn.init_runtime() != 0:
            raise RuntimeError("Failed to initialize RKNN runtime")

    @staticmethod
    def _companion_onnx_path(model_path: Path) -> str:
        """Return the path of the ONNX file that accompanies ``model_path``.

        Only the final ``.rknn`` suffix is swapped for ``.onnx``; directory
        names that happen to contain ``.rknn`` are left alone. A path that
        does not end in ``.rknn`` is returned unchanged.
        """
        if model_path.suffix == ".rknn":
            return str(model_path.with_suffix(".onnx"))
        return str(model_path)

    def get_inputs(self) -> list[SessionNode]:
        """Return the input nodes reported by the companion ONNX session."""
        inputs: list[SessionNode] = self.ort_session.get_inputs()
        return inputs

    def get_outputs(self) -> list[SessionNode]:
        """Return the output nodes reported by the companion ONNX session."""
        outputs: list[SessionNode] = self.ort_session.get_outputs()
        return outputs

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ):
        """Run inference on the RKNN runtime.

        Args:
            output_names: Accepted but not used; every output produced by
                the runtime is returned.
            input_feed: Mapping of input name to array. The arrays are
                passed to the runtime positionally, in the mapping's
                iteration order.
                NOTE(review): assumes that order matches the model's input
                order - confirm for multi-input models.
            run_options: Accepted but not used.

        Returns:
            Whatever ``RKNNLite.inference`` returns for the given inputs.
            If both attempts (see below) yield nothing, that empty result
            is returned as-is rather than raising.

        Raises:
            ValueError: If ``input_feed`` is empty.
        """
        if not input_feed:
            raise ValueError("input_feed must contain at least one input array")

        # Hand every input to the runtime as a C-contiguous array.
        inputs = [np.ascontiguousarray(value) for value in input_feed.values()]

        # First attempt uses the runtime's default data format; if nothing
        # comes back (None or an empty result), retry declaring NCHW input.
        outputs = self.rknn.inference(inputs=inputs)
        if not outputs:
            outputs = self.rknn.inference(inputs=inputs, data_format="nchw")
        return outputs