From 32f3707e5255efe2b4c4b5fe4a7a2b3bb416cbda Mon Sep 17 00:00:00 2001 From: yoni13 Date: Sat, 18 Jan 2025 17:17:48 +0800 Subject: [PATCH] fix types and ignored pattern --- machine-learning/app/models/base.py | 1 + machine-learning/app/test_main.py | 2 +- machine-learning/rknn/rknnpool.py | 5 +++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py index bea3684c66..571d33de22 100644 --- a/machine-learning/app/models/base.py +++ b/machine-learning/app/models/base.py @@ -69,6 +69,7 @@ class InferenceModel(ABC): def _download(self) -> None: ignored_patterns: dict[ModelFormat, list[str]] = { + ModelFormat.ONNX: ["*.armnn", "*.rknn"], ModelFormat.ARMNN: ["*.rknn"], ModelFormat.RKNN: ["*.armnn"], } diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py index a7648c5430..29a5b50e1f 100644 --- a/machine-learning/app/test_main.py +++ b/machine-learning/app/test_main.py @@ -355,7 +355,7 @@ class TestRknnSession: RknnSession(model_path) rknn_session.assert_called_once_with( - rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), TPEs=tpe, func=runInference + rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=runInference ) info.assert_has_calls( diff --git a/machine-learning/rknn/rknnpool.py b/machine-learning/rknn/rknnpool.py index e8a3a8bcea..ed2e48c4a6 100644 --- a/machine-learning/rknn/rknnpool.py +++ b/machine-learning/rknn/rknnpool.py @@ -5,6 +5,7 @@ import os from concurrent.futures import ThreadPoolExecutor from queue import Queue import numpy as np +from typing import Callable from numpy.typing import NDArray from app.config import log @@ -31,7 +32,7 @@ except (FileNotFoundError, ImportError): soc_name = None -def init_rknn(rknnModel, id) -> RKNNLite: +def init_rknn(rknnModel, id) -> Callable: if not is_available: raise RuntimeError("rknn is not available!") rknn_lite = RKNNLite() @@ -59,7 +60,7 @@ def init_rknn(rknnModel, id) -> RKNNLite: return rknn_lite -def init_rknns(rknnModel, tpes) -> list[RKNNLite]: +def init_rknns(rknnModel, tpes) -> list[Callable]: rknn_list = [] for i in range(tpes): rknn_list.append(init_rknn(rknnModel, i % 3))