add input outputs

This commit is contained in:
yoni13 2025-01-19 12:10:01 +00:00
parent 2b967ca358
commit ac4ce3ea9c

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, NamedTuple
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
@ -18,6 +18,27 @@ def runInference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArr
return outputs return outputs
input_output_mapping = {
"buffalo_l": {
"detection": {
"input": {"norm_tensor:0": (1, 3, 640, 640)},
"output": {
"norm_tensor:1": (12800, 1),
"norm_tensor:2": (3200, 1),
"norm_tensor:3": (800, 1),
"norm_tensor:4": (12800, 4),
"norm_tensor:5": (3200, 4),
"norm_tensor:6": (800, 4),
"norm_tensor:7": (12800, 10),
"norm_tensor:8": (3200, 10),
"norm_tensor:9": (800, 10),
},
},
"recognition": {"input": {"norm_tensor:0": (1, 3, 112, 112)}, "output": {"norm_tensor:1": (1, 512)}},
}
}
class RknnSession: class RknnSession:
def __init__(self, model_path: Path | str): def __init__(self, model_path: Path | str):
self.model_path = Path(str(model_path).replace("model", soc_name)) self.model_path = Path(str(model_path).replace("model", soc_name))
@ -32,11 +53,25 @@ class RknnSession:
self.rknnpool.release() self.rknnpool.release()
def get_inputs(self) -> list[SessionNode]: def get_inputs(self) -> list[SessionNode]:
raise NotImplementedError for model_name in input_output_mapping:
if model_name in self.model_path.as_posix():
model_type = "detection" if "detection" in self.model_path.as_posix() else "recognition"
return [
RknnNode(name=k, shape=v)
for k, v in input_output_mapping[model_name][model_type]["input"].items()
]
raise ValueError(f"Model {self.model_path} not found in input_output_mapping.")
def get_outputs(self) -> list[SessionNode]: def get_outputs(self) -> list[SessionNode]:
raise NotImplementedError for model_name in input_output_mapping:
if model_name in self.model_path.as_posix():
model_type = "detection" if "detection" in self.model_path.as_posix() else "recognition"
return [
RknnNode(name=k, shape=v)
for k, v in input_output_mapping[model_name][model_type]["output"].items()
]
raise ValueError(f"Model {self.model_path} not found in input_output_mapping.")
def run( def run(
self, self,
output_names: list[str] | None, output_names: list[str] | None,
@ -47,3 +82,8 @@ class RknnSession:
self.rknnpool.put(input_data) self.rknnpool.put(input_data)
outputs: list[NDArray[np.float32]] = self.rknnpool.get() outputs: list[NDArray[np.float32]] = self.rknnpool.get()
return outputs return outputs
class RknnNode(NamedTuple):
name: str | None
shape: tuple[int, ...]