fix types and ignored pattern

This commit is contained in:
yoni13 2025-01-18 17:17:48 +08:00
parent 58f1cc92d7
commit 32f3707e52
3 changed files with 5 additions and 3 deletions

View File

@ -69,6 +69,7 @@ class InferenceModel(ABC):
def _download(self) -> None: def _download(self) -> None:
ignored_patterns: dict[ModelFormat, list[str]] = { ignored_patterns: dict[ModelFormat, list[str]] = {
ModelFormat.ONNX: ["*.armnn", "*.rknn"],
ModelFormat.ARMNN: ["*.rknn"], ModelFormat.ARMNN: ["*.rknn"],
ModelFormat.RKNN: ["*.armnn"], ModelFormat.RKNN: ["*.armnn"],
} }

View File

@ -355,7 +355,7 @@ class TestRknnSession:
RknnSession(model_path) RknnSession(model_path)
rknn_session.assert_called_once_with( rknn_session.assert_called_once_with(
rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), TPEs=tpe, func=runInference rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=runInference
) )
info.assert_has_calls( info.assert_has_calls(

View File

@ -5,6 +5,7 @@ import os
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from queue import Queue from queue import Queue
import numpy as np import numpy as np
from typing import Callable
from numpy.typing import NDArray from numpy.typing import NDArray
from app.config import log from app.config import log
@ -31,7 +32,7 @@ except (FileNotFoundError, ImportError):
soc_name = None soc_name = None
def init_rknn(rknnModel, id) -> RKNNLite: def init_rknn(rknnModel, id) -> Callable:
if not is_available: if not is_available:
raise RuntimeError("rknn is not available!") raise RuntimeError("rknn is not available!")
rknn_lite = RKNNLite() rknn_lite = RKNNLite()
@ -59,7 +60,7 @@ def init_rknn(rknnModel, id) -> RKNNLite:
return rknn_lite return rknn_lite
def init_rknns(rknnModel, tpes) -> list[RKNNLite]: def init_rknns(rknnModel, tpes) -> list[Callable]:
rknn_list = [] rknn_list = []
for i in range(tpes): for i in range(tpes):
rknn_list.append(init_rknn(rknnModel, i % 3)) rknn_list.append(init_rknn(rknnModel, i % 3))