more linting

This commit is contained in:
mertalev 2025-03-14 15:30:11 -04:00
parent bd9374e4a9
commit 4db85a8954
No known key found for this signature in database
GPG Key ID: F7C271C07CF04AAE
3 changed files with 23 additions and 12 deletions

View File

@ -12,7 +12,7 @@ from app.schemas import SessionNode
from .rknnpool import RknnPoolExecutor, is_available, soc_name from .rknnpool import RknnPoolExecutor, is_available, soc_name
is_available = is_available and settings.rknn is_available = is_available and settings.rknn
model_prefix = Path("rknpu") / soc_name if is_available else None model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None
def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]: def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
@ -62,10 +62,15 @@ class RknnSession:
) -> list[NDArray[np.float32]]: ) -> list[NDArray[np.float32]]:
input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()] input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
self.rknnpool.put(input_data) self.rknnpool.put(input_data)
outputs: list[NDArray[np.float32]] = self.rknnpool.get() res = self.rknnpool.get()
return outputs if res is None:
raise RuntimeError("RKNN inference failed!")
return res
class RknnNode(NamedTuple): class RknnNode(NamedTuple):
name: str | None name: str | None
shape: tuple[int, ...] shape: tuple[int, ...]
__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"]

View File

@ -1,9 +1,10 @@
# This code is from leafqycc/rknn-multi-threaded # This code is from leafqycc/rknn-multi-threaded
# Following Apache License 2.0 # Following Apache License 2.0
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Callable, TypeVar
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
@ -21,7 +22,7 @@ def get_soc(device_tree_path: Path | str) -> str | None:
return soc return soc
log.warning("Device is not supported for RKNN") log.warning("Device is not supported for RKNN")
except OSError as e: except OSError as e:
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e.msg) log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e)
return None return None
@ -56,19 +57,24 @@ def init_rknn(model_path: str) -> "RKNNLite":
class RknnPoolExecutor: class RknnPoolExecutor:
def __init__(self, model_path: str, tpes: int, func): def __init__(
self,
model_path: str,
tpes: int,
func: Callable[["RKNNLite", list[NDArray[np.float32]]], list[NDArray[np.float32]]],
) -> None:
self.tpes = tpes self.tpes = tpes
self.queue = Queue() self.queue: Queue[Future[list[NDArray[np.float32]]]] = Queue()
self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)] self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
self.pool = ThreadPoolExecutor(max_workers=tpes) self.pool = ThreadPoolExecutor(max_workers=tpes)
self.func = func self.func = func
self.num = 0 self.num = 0
def put(self, inputs) -> None: def put(self, inputs: list[NDArray[np.float32]]) -> None:
self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs)) self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
self.num += 1 self.num += 1
def get(self) -> list[list[NDArray[np.float32]], bool]: def get(self) -> list[NDArray[np.float32]] | None:
if self.queue.empty(): if self.queue.empty():
return None return None
fut = self.queue.get() fut = self.queue.get()

View File

@ -25,7 +25,7 @@ from app.models.facial_recognition.detection import FaceDetector
from app.models.facial_recognition.recognition import FaceRecognizer from app.models.facial_recognition.recognition import FaceRecognizer
from app.sessions.ann import AnnSession from app.sessions.ann import AnnSession
from app.sessions.ort import OrtSession from app.sessions.ort import OrtSession
from app.sessions.rknn import RknnSession, runInference from app.sessions.rknn import RknnSession, run_inference
from .config import Settings, settings from .config import Settings, settings
from .models.base import InferenceModel from .models.base import InferenceModel
@ -356,11 +356,11 @@ class TestRknnSession:
RknnSession(model_path) RknnSession(model_path)
rknn_session.assert_called_once_with( rknn_session.assert_called_once_with(
rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=runInference rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=run_inference
) )
info.assert_has_calls( info.assert_has_calls(
[mock.call(f"Loaded RKNN model from {str(model_path).replace('model','rk3566')} with {tpe} threads.")] [mock.call(f"Loaded RKNN model from {str(model_path).replace('model', 'rk3566')} with {tpe} threads.")]
) )
def test_run_rknn(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None: def test_run_rknn(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None: