From ac4ce3ea9c72a580258057ac7138f409dd4b5d48 Mon Sep 17 00:00:00 2001 From: yoni13 Date: Sun, 19 Jan 2025 12:10:01 +0000 Subject: [PATCH] add input outputs --- machine-learning/app/sessions/rknn.py | 48 ++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/machine-learning/app/sessions/rknn.py b/machine-learning/app/sessions/rknn.py index 4121839afb..3d234fc6fa 100644 --- a/machine-learning/app/sessions/rknn.py +++ b/machine-learning/app/sessions/rknn.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Any, NamedTuple import numpy as np from numpy.typing import NDArray @@ -18,6 +18,27 @@ def runInference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArr return outputs +input_output_mapping = { + "buffalo_l": { + "detection": { + "input": {"norm_tensor:0": (1, 3, 640, 640)}, + "output": { + "norm_tensor:1": (12800, 1), + "norm_tensor:2": (3200, 1), + "norm_tensor:3": (800, 1), + "norm_tensor:4": (12800, 4), + "norm_tensor:5": (3200, 4), + "norm_tensor:6": (800, 4), + "norm_tensor:7": (12800, 10), + "norm_tensor:8": (3200, 10), + "norm_tensor:9": (800, 10), + }, + }, + "recognition": {"input": {"norm_tensor:0": (1, 3, 112, 112)}, "output": {"norm_tensor:1": (1, 512)}}, + } +} + + class RknnSession: def __init__(self, model_path: Path | str): self.model_path = Path(str(model_path).replace("model", soc_name)) @@ -32,11 +53,25 @@ class RknnSession: self.rknnpool.release() def get_inputs(self) -> list[SessionNode]: - raise NotImplementedError + for model_name in input_output_mapping: + if model_name in self.model_path.as_posix(): + model_type = "detection" if "detection" in self.model_path.as_posix() else "recognition" + return [ + RknnNode(name=k, shape=v) + for k, v in input_output_mapping[model_name][model_type]["input"].items() + ] + raise ValueError(f"Model {self.model_path} not found in input_output_mapping.") def get_outputs(self) -> list[SessionNode]: - raise NotImplementedError - + for model_name in input_output_mapping: + if model_name in self.model_path.as_posix(): + model_type = "detection" if "detection" in self.model_path.as_posix() else "recognition" + return [ + RknnNode(name=k, shape=v) + for k, v in input_output_mapping[model_name][model_type]["output"].items() + ] + raise ValueError(f"Model {self.model_path} not found in input_output_mapping.") + def run( self, output_names: list[str] | None, @@ -47,3 +82,8 @@ class RknnSession: self.rknnpool.put(input_data) outputs: list[NDArray[np.float32]] = self.rknnpool.get() return outputs + + +class RknnNode(NamedTuple): + name: str | None + shape: tuple[int, ...]