from __future__ import annotations

from pathlib import Path
from typing import Any, NamedTuple

import numpy as np
from numpy.typing import NDArray

from immich_ml.config import log, settings
from immich_ml.schemas import SessionNode

from .rknnpool import RknnPoolExecutor, is_available, soc_name

# RKNN is usable only when the runtime is present and enabled in settings.
is_available = is_available and settings.rknn
# Models live in a per-SoC subdirectory of "rknpu" when RKNN is available.
model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None


def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
    # Inference callback executed by each worker in the RKNN thread pool.
    outputs: list[NDArray[np.float32]] = rknn_lite.inference(inputs=input, data_format="nchw")
    return outputs


# Static input/output tensor names and shapes for the supported model types.
input_output_mapping: dict[str, dict[str, Any]] = {
    "detection": {
        "input": {"norm_tensor:0": (1, 3, 640, 640)},
        "output": {
            "norm_tensor:1": (12800, 1),
            "norm_tensor:2": (3200, 1),
            "norm_tensor:3": (800, 1),
            "norm_tensor:4": (12800, 4),
            "norm_tensor:5": (3200, 4),
            "norm_tensor:6": (800, 4),
            "norm_tensor:7": (12800, 10),
            "norm_tensor:8": (3200, 10),
            "norm_tensor:9": (800, 10),
        },
    },
    "recognition": {"input": {"norm_tensor:0": (1, 3, 112, 112)}, "output": {"norm_tensor:1": (1, 512)}},
}


class RknnSession:
    def __init__(self, model_path: Path) -> None:
        self.model_type = "detection" if "detection" in model_path.parts else "recognition"
        self.tpe = settings.rknn_threads

        log.info(f"Loading RKNN model from {model_path} with {self.tpe} threads.")
        self.rknnpool = RknnPoolExecutor(model_path=model_path.as_posix(), tpes=self.tpe, func=run_inference)
        log.info(f"Loaded RKNN model from {model_path} with {self.tpe} threads.")

    def get_inputs(self) -> list[SessionNode]:
        return [RknnNode(name=k, shape=v) for k, v in input_output_mapping[self.model_type]["input"].items()]

    def get_outputs(self) -> list[SessionNode]:
        return [RknnNode(name=k, shape=v) for k, v in input_output_mapping[self.model_type]["output"].items()]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        # output_names and run_options are accepted for interface compatibility and are ignored.
        input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        self.rknnpool.put(input_data)
        res = self.rknnpool.get()
        if res is None:
            raise RuntimeError("RKNN inference failed!")
        return res


class RknnNode(NamedTuple):
    name: str | None
    shape: tuple[int, ...]


__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"]