diff --git a/machine-learning/app/sessions/rknn.py b/machine-learning/app/sessions/rknn.py index 8df46c1034..c79e669632 100644 --- a/machine-learning/app/sessions/rknn.py +++ b/machine-learning/app/sessions/rknn.py @@ -13,8 +13,8 @@ from rknn.rknnpool import rknnPoolExecutor, soc_name from ..config import log, settings -def runInfrence(rknn_lite: Any, input: NDArray[np.float32]) -> list[NDArray[np.float32]]: - outputs: list[NDArray[np.float32]] = rknn_lite.inference(inputs=[input], data_format="nchw") +def runInfrence(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]: + outputs: list[NDArray[np.float32]] = rknn_lite.inference(inputs=input, data_format="nchw") return outputs @@ -66,7 +66,7 @@ class RknnSession: input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]], run_options: Any = None, ) -> list[NDArray[np.float32]]: - input_data: NDArray[np.float32] = np.ascontiguousarray(list(input_feed.values())[0], dtype=np.float32) + input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()] self.rknnpool.put(input_data) - outputs: list[NDArray[np.float32]] = self.rknnpool.get()[0] + outputs: list[NDArray[np.float32]] = self.rknnpool.get() return outputs diff --git a/machine-learning/rknn/rknnpool.py b/machine-learning/rknn/rknnpool.py index 293ef9bd8e..9fd365b13c 100644 --- a/machine-learning/rknn/rknnpool.py +++ b/machine-learning/rknn/rknnpool.py @@ -23,7 +23,7 @@ try: else: is_available = False soc_name = None - is_available = os.path.exists("/sys/kernel/debug/rknpu/load") + is_available = is_available and os.path.exists("/sys/kernel/debug/rknpu/load") except (FileNotFoundError, ImportError): is_available = False soc_name = None @@ -79,9 +79,9 @@ class rknnPoolExecutor: def get(self) -> list[list[NDArray[np.float32]], bool]: if self.queue.empty(): - return None, False + return None fut = self.queue.get() - return fut.result(), True + return fut.result() def release(self) -> None: self.pool.shutdown()