From 4db85a895409aebedfb8d885f52b3c03b562eab2 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Fri, 14 Mar 2025 15:30:11 -0400 Subject: [PATCH] more linting --- machine-learning/app/sessions/rknn/__init__.py | 11 ++++++++--- machine-learning/app/sessions/rknn/rknnpool.py | 18 ++++++++++++------ machine-learning/app/test_main.py | 6 +++--- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/machine-learning/app/sessions/rknn/__init__.py b/machine-learning/app/sessions/rknn/__init__.py index a35ffcacab..2b72c03dec 100644 --- a/machine-learning/app/sessions/rknn/__init__.py +++ b/machine-learning/app/sessions/rknn/__init__.py @@ -12,7 +12,7 @@ from app.schemas import SessionNode from .rknnpool import RknnPoolExecutor, is_available, soc_name is_available = is_available and settings.rknn -model_prefix = Path("rknpu") / soc_name if is_available else None +model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]: @@ -62,10 +62,15 @@ class RknnSession: ) -> list[NDArray[np.float32]]: input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()] self.rknnpool.put(input_data) - outputs: list[NDArray[np.float32]] = self.rknnpool.get() - return outputs + res = self.rknnpool.get() + if res is None: + raise RuntimeError("RKNN inference failed!") + return res class RknnNode(NamedTuple): name: str | None shape: tuple[int, ...] + + +__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"] diff --git a/machine-learning/app/sessions/rknn/rknnpool.py b/machine-learning/app/sessions/rknn/rknnpool.py index 3c17cd9754..38d64ab78a 100644 --- a/machine-learning/app/sessions/rknn/rknnpool.py +++ b/machine-learning/app/sessions/rknn/rknnpool.py @@ -1,9 +1,10 @@ # This code is from leafqycc/rknn-multi-threaded # Following Apache License 2.0 -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path from queue import Queue +from typing import Callable, TypeVar import numpy as np from numpy.typing import NDArray @@ -21,7 +22,7 @@ def get_soc(device_tree_path: Path | str) -> str | None: return soc log.warning("Device is not supported for RKNN") except OSError as e: - log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e.msg) + log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e) return None @@ -56,19 +57,24 @@ def init_rknn(model_path: str) -> "RKNNLite": class RknnPoolExecutor: - def __init__(self, model_path: str, tpes: int, func): + def __init__( + self, + model_path: str, + tpes: int, + func: Callable[["RKNNLite", list[NDArray[np.float32]]], list[NDArray[np.float32]]], + ) -> None: self.tpes = tpes - self.queue = Queue() + self.queue: Queue[Future[list[NDArray[np.float32]]]] = Queue() self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)] self.pool = ThreadPoolExecutor(max_workers=tpes) self.func = func self.num = 0 - def put(self, inputs) -> None: + def put(self, inputs: list[NDArray[np.float32]]) -> None: self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs)) self.num += 1 - def get(self) -> list[list[NDArray[np.float32]], bool]: + def get(self) -> list[NDArray[np.float32]] | None: if self.queue.empty(): return None fut = self.queue.get() diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py index 9dc589bd06..4f0f1e7cc8 100644 --- a/machine-learning/app/test_main.py +++ b/machine-learning/app/test_main.py @@ -25,7 +25,7 @@ from app.models.facial_recognition.detection import FaceDetector from app.models.facial_recognition.recognition import FaceRecognizer from app.sessions.ann import AnnSession from app.sessions.ort import OrtSession -from app.sessions.rknn import RknnSession, runInference +from app.sessions.rknn import RknnSession, run_inference from .config import Settings, settings from .models.base import InferenceModel @@ -356,11 +356,11 @@ class TestRknnSession: RknnSession(model_path) rknn_session.assert_called_once_with( - rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=runInference + rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=run_inference ) info.assert_has_calls( - [mock.call(f"Loaded RKNN model from {str(model_path).replace('model','rk3566')} with {tpe} threads.")] + [mock.call(f"Loaded RKNN model from {str(model_path).replace('model', 'rk3566')} with {tpe} threads.")] ) def test_run_rknn(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None: