more linting

This commit is contained in:
mertalev 2025-03-14 15:30:11 -04:00
parent bd9374e4a9
commit 4db85a8954
No known key found for this signature in database
GPG Key ID: F7C271C07CF04AAE
3 changed files with 23 additions and 12 deletions

View File

@ -12,7 +12,7 @@ from app.schemas import SessionNode
from .rknnpool import RknnPoolExecutor, is_available, soc_name
is_available = is_available and settings.rknn
model_prefix = Path("rknpu") / soc_name if is_available else None
model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None
def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
@ -62,10 +62,15 @@ class RknnSession:
) -> list[NDArray[np.float32]]:
input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
self.rknnpool.put(input_data)
outputs: list[NDArray[np.float32]] = self.rknnpool.get()
return outputs
res = self.rknnpool.get()
if res is None:
raise RuntimeError("RKNN inference failed!")
return res
class RknnNode(NamedTuple):
name: str | None
shape: tuple[int, ...]
__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"]

View File

@ -1,9 +1,10 @@
# This code is from leafqycc/rknn-multi-threaded
# Following Apache License 2.0
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from queue import Queue
from typing import Callable, TypeVar
import numpy as np
from numpy.typing import NDArray
@ -21,7 +22,7 @@ def get_soc(device_tree_path: Path | str) -> str | None:
return soc
log.warning("Device is not supported for RKNN")
except OSError as e:
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e.msg)
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e)
return None
@ -56,19 +57,24 @@ def init_rknn(model_path: str) -> "RKNNLite":
class RknnPoolExecutor:
def __init__(self, model_path: str, tpes: int, func):
def __init__(
self,
model_path: str,
tpes: int,
func: Callable[["RKNNLite", list[NDArray[np.float32]]], list[NDArray[np.float32]]],
) -> None:
self.tpes = tpes
self.queue = Queue()
self.queue: Queue[Future[list[NDArray[np.float32]]]] = Queue()
self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
self.pool = ThreadPoolExecutor(max_workers=tpes)
self.func = func
self.num = 0
def put(self, inputs) -> None:
def put(self, inputs: list[NDArray[np.float32]]) -> None:
self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
self.num += 1
def get(self) -> list[list[NDArray[np.float32]], bool]:
def get(self) -> list[NDArray[np.float32]] | None:
if self.queue.empty():
return None
fut = self.queue.get()

View File

@ -25,7 +25,7 @@ from app.models.facial_recognition.detection import FaceDetector
from app.models.facial_recognition.recognition import FaceRecognizer
from app.sessions.ann import AnnSession
from app.sessions.ort import OrtSession
from app.sessions.rknn import RknnSession, runInference
from app.sessions.rknn import RknnSession, run_inference
from .config import Settings, settings
from .models.base import InferenceModel
@ -356,11 +356,11 @@ class TestRknnSession:
RknnSession(model_path)
rknn_session.assert_called_once_with(
rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=runInference
rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=run_inference
)
info.assert_has_calls(
[mock.call(f"Loaded RKNN model from {str(model_path).replace('model','rk3566')} with {tpe} threads.")]
[mock.call(f"Loaded RKNN model from {str(model_path).replace('model', 'rk3566')} with {tpe} threads.")]
)
def test_run_rknn(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None: