2025-01-13 17:08:16 +08:00

72 lines
2.3 KiB
Python

from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray
from app.schemas import SessionNode
from rknn.rknnpool import rknnPoolExecutor, soc_name
from ..config import log, settings
def runInfrence(rknn_lite, input):
    """Run one inference call on an RKNN-Lite runtime instance.

    This is handed to ``rknnPoolExecutor`` as its ``func`` callback.

    Args:
        rknn_lite: Runtime object exposing ``inference(inputs=..., data_format=...)``.
        input: The single input tensor. It is wrapped in a one-element list and
            submitted with ``data_format="nchw"``. (The name shadows the builtin
            but is kept for interface compatibility.)

    Returns:
        Whatever ``rknn_lite.inference`` returns, unchanged.
    """
    batch = [input]
    return rknn_lite.inference(inputs=batch, data_format="nchw")
class RknnSession:
    """Inference session that runs a model through a pool of RKNN-Lite workers.

    The public methods mirror ``onnxruntime.InferenceSession`` names
    (``get_inputs``, ``get_outputs``, ``run``). Input/output node metadata is
    read from a companion ONNX file. Inference itself is dispatched through
    ``rknnPoolExecutor``.
    """

    def __init__(self, model_path: Path | str):
        """Resolve the SoC-specific model files and start the worker pool.

        Args:
            model_path: Path to the model. Every occurrence of ``"model"`` in
                its *file name* is replaced with ``soc_name`` to locate the
                SoC-specific RKNN file (e.g. ``model.rknn`` -> ``<soc>.rknn``).
                When the resolved file name contains ``<soc>.rknn``, that part
                is swapped for ``model.onnx`` (same directory) to find the ONNX
                file used for node metadata.
        """
        requested = Path(model_path)
        # Rewrite only the file name. Rewriting the whole path string would also
        # mangle any parent directory whose name happens to contain "model".
        self.model_path = requested.with_name(requested.name.replace("model", soc_name))
        self.ort_model_path = self.model_path.with_name(
            self.model_path.name.replace(f"{soc_name}.rknn", "model.onnx")
        )
        # Node metadata is loaded lazily from the ONNX file on first request.
        self.inputs: list[SessionNode] | None = None
        self.outputs: list[SessionNode] | None = None

        # Worker count is chosen from a substring of the model path. Anything
        # that is neither "textual" nor "visual" uses the facial-detection setting.
        path_text = str(self.model_path)
        if "textual" in path_text:
            self.tpe = settings.rknn_textual_threads
        elif "visual" in path_text:
            self.tpe = settings.rknn_visual_threads
        else:
            self.tpe = settings.rknn_facial_detection_threads

        log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.")
        self.rknnpool = rknnPoolExecutor(rknnModel=self.model_path.as_posix(), TPEs=self.tpe, func=runInfrence)

    def __del__(self) -> None:
        """Release the worker pool, if one was ever created."""
        # __init__ may have raised before the pool was assigned (e.g. the pool
        # constructor failed). Don't turn that into a second error at GC time.
        pool = getattr(self, "rknnpool", None)
        if pool is not None:
            pool.release()

    def _load_io_metadata(self) -> None:
        """Populate ``self.inputs``/``self.outputs`` from the companion ONNX model."""
        # A short-lived onnxruntime session is opened only to read node metadata.
        # It is not kept around and never used for inference.
        session = ort.InferenceSession(self.ort_model_path.as_posix())
        self.inputs = session.get_inputs()
        self.outputs = session.get_outputs()

    def get_inputs(self) -> list[SessionNode]:
        """Return the model's input node metadata, loading it on first use."""
        if self.inputs is None:
            self._load_io_metadata()
        return self.inputs

    def get_outputs(self) -> list[SessionNode]:
        """Return the model's output node metadata, loading it on first use."""
        if self.outputs is None:
            self._load_io_metadata()
        return self.outputs

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ):
        """Run inference on the first tensor of ``input_feed``.

        Args:
            output_names: Accepted to match the onnxruntime ``run`` signature;
                not used.
            input_feed: Mapping of input name to array. Only the first value
                (in dict order) is submitted. Any further inputs are ignored.
            run_options: Accepted to match the onnxruntime ``run`` signature;
                not used.

        Returns:
            The first element of ``self.rknnpool.get()``. This is presumably the
            outputs produced by ``runInfrence`` for this input (TODO confirm
            against rknnPoolExecutor).

        Raises:
            ValueError: If ``input_feed`` is empty.
        """
        if not input_feed:
            raise ValueError("input_feed must contain at least one input array")
        # The RKNN runtime is fed a C-contiguous buffer. Only the first input
        # is converted, since it is the only one used.
        input_data = np.ascontiguousarray(next(iter(input_feed.values())))
        # NOTE(review): put()/get() are separate calls on a shared pool. If
        # several threads call run() concurrently, a caller may receive another
        # caller's result unless the pool pairs them. Verify before sharing a
        # session across threads.
        self.rknnpool.put(input_data)
        return self.rknnpool.get()[0]