diff --git a/machine-learning/app/sessions/rknn.py b/machine-learning/app/sessions/rknn.py index 3d234fc6fa..fe490af074 100644 --- a/machine-learning/app/sessions/rknn.py +++ b/machine-learning/app/sessions/rknn.py @@ -18,31 +18,29 @@ def runInference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArr return outputs -input_output_mapping = { - "buffalo_l": { - "detection": { - "input": {"norm_tensor:0": (1, 3, 640, 640)}, - "output": { - "norm_tensor:1": (12800, 1), - "norm_tensor:2": (3200, 1), - "norm_tensor:3": (800, 1), - "norm_tensor:4": (12800, 4), - "norm_tensor:5": (3200, 4), - "norm_tensor:6": (800, 4), - "norm_tensor:7": (12800, 10), - "norm_tensor:8": (3200, 10), - "norm_tensor:9": (800, 10), - }, +input_output_mapping: dict[str, dict[str, Any]] = { + "detection": { + "input": {"norm_tensor:0": (1, 3, 640, 640)}, + "output": { + "norm_tensor:1": (12800, 1), + "norm_tensor:2": (3200, 1), + "norm_tensor:3": (800, 1), + "norm_tensor:4": (12800, 4), + "norm_tensor:5": (3200, 4), + "norm_tensor:6": (800, 4), + "norm_tensor:7": (12800, 10), + "norm_tensor:8": (3200, 10), + "norm_tensor:9": (800, 10), }, - "recognition": {"input": {"norm_tensor:0": (1, 3, 112, 112)}, "output": {"norm_tensor:1": (1, 512)}}, - } + }, + "recognition": {"input": {"norm_tensor:0": (1, 3, 112, 112)}, "output": {"norm_tensor:1": (1, 512)}}, } class RknnSession: def __init__(self, model_path: Path | str): self.model_path = Path(str(model_path).replace("model", soc_name)) - + self.model_type = "detection" if "detection" in self.model_path.as_posix() else "recognition" self.tpe = settings.rknn_threads log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.") @@ -53,24 +51,10 @@ class RknnSession: self.rknnpool.release() def get_inputs(self) -> list[SessionNode]: - for model_name in input_output_mapping: - if model_name in self.model_path.as_posix(): - model_type = "detection" if "detection" in self.model_path.as_posix() else "recognition" - return [ - RknnNode(name=k, shape=v) - for k, v in input_output_mapping[model_name][model_type]["input"].items() - ] - raise ValueError(f"Model {self.model_path} not found in input_output_mapping.") + return [RknnNode(name=k, shape=v) for k, v in input_output_mapping[self.model_type]["input"].items()] def get_outputs(self) -> list[SessionNode]: - for model_name in input_output_mapping: - if model_name in self.model_path.as_posix(): - model_type = "detection" if "detection" in self.model_path.as_posix() else "recognition" - return [ - RknnNode(name=k, shape=v) - for k, v in input_output_mapping[model_name][model_type]["output"].items() - ] - raise ValueError(f"Model {self.model_path} not found in input_output_mapping.") + return [RknnNode(name=k, shape=v) for k, v in input_output_mapping[self.model_type]["output"].items()] def run( self,