diff --git a/machine-learning/rknn/rknnpool.py b/machine-learning/rknn/rknnpool.py index 17c0db021f..14fa06481e 100644 --- a/machine-learning/rknn/rknnpool.py +++ b/machine-learning/rknn/rknnpool.py @@ -32,7 +32,7 @@ except (FileNotFoundError, ImportError): soc_name = None -def init_rknn(rknnModel, id) -> Callable: +def init_rknn(rknnModel) -> Callable: if not is_available: raise RuntimeError("rknn is not available!") rknn_lite = RKNNLite() @@ -41,16 +41,7 @@ def init_rknn(rknnModel, id) -> Callable: raise RuntimeError("Load RKNN rknnModel failed") if soc_name in coremask_supported_socs: - if id == 0: - ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0) - elif id == 1: - ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_1) - elif id == 2: - ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_2) - elif id == -1: - ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2) - else: - ret = rknn_lite.init_runtime() + ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_AUTO) else: ret = rknn_lite.init_runtime() # Please do not set this parameter on other platforms. @@ -63,7 +54,7 @@ def init_rknn(rknnModel, id) -> Callable: def init_rknns(rknnModel, tpes) -> list[Callable]: rknn_list = [] for i in range(tpes): - rknn_list.append(init_rknn(rknnModel, i % 3)) + rknn_list.append(init_rknn(rknnModel)) return rknn_list