mirror of
https://github.com/immich-app/immich.git
synced 2025-07-09 03:04:16 -04:00
more linting
This commit is contained in:
parent
bd9374e4a9
commit
4db85a8954
@ -12,7 +12,7 @@ from app.schemas import SessionNode
|
||||
from .rknnpool import RknnPoolExecutor, is_available, soc_name
|
||||
|
||||
is_available = is_available and settings.rknn
|
||||
model_prefix = Path("rknpu") / soc_name if is_available else None
|
||||
model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None
|
||||
|
||||
|
||||
def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
||||
@ -62,10 +62,15 @@ class RknnSession:
|
||||
) -> list[NDArray[np.float32]]:
|
||||
input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
|
||||
self.rknnpool.put(input_data)
|
||||
outputs: list[NDArray[np.float32]] = self.rknnpool.get()
|
||||
return outputs
|
||||
res = self.rknnpool.get()
|
||||
if res is None:
|
||||
raise RuntimeError("RKNN inference failed!")
|
||||
return res
|
||||
|
||||
|
||||
class RknnNode(NamedTuple):
|
||||
name: str | None
|
||||
shape: tuple[int, ...]
|
||||
|
||||
|
||||
__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"]
|
||||
|
@ -1,9 +1,10 @@
|
||||
# This code is from leafqycc/rknn-multi-threaded
|
||||
# Following Apache License 2.0
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
@ -21,7 +22,7 @@ def get_soc(device_tree_path: Path | str) -> str | None:
|
||||
return soc
|
||||
log.warning("Device is not supported for RKNN")
|
||||
except OSError as e:
|
||||
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e.msg)
|
||||
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
@ -56,19 +57,24 @@ def init_rknn(model_path: str) -> "RKNNLite":
|
||||
|
||||
|
||||
class RknnPoolExecutor:
|
||||
def __init__(self, model_path: str, tpes: int, func):
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
tpes: int,
|
||||
func: Callable[["RKNNLite", list[NDArray[np.float32]]], list[NDArray[np.float32]]],
|
||||
) -> None:
|
||||
self.tpes = tpes
|
||||
self.queue = Queue()
|
||||
self.queue: Queue[Future[list[NDArray[np.float32]]]] = Queue()
|
||||
self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
|
||||
self.pool = ThreadPoolExecutor(max_workers=tpes)
|
||||
self.func = func
|
||||
self.num = 0
|
||||
|
||||
def put(self, inputs) -> None:
|
||||
def put(self, inputs: list[NDArray[np.float32]]) -> None:
|
||||
self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
|
||||
self.num += 1
|
||||
|
||||
def get(self) -> list[list[NDArray[np.float32]], bool]:
|
||||
def get(self) -> list[NDArray[np.float32]] | None:
|
||||
if self.queue.empty():
|
||||
return None
|
||||
fut = self.queue.get()
|
||||
|
@ -25,7 +25,7 @@ from app.models.facial_recognition.detection import FaceDetector
|
||||
from app.models.facial_recognition.recognition import FaceRecognizer
|
||||
from app.sessions.ann import AnnSession
|
||||
from app.sessions.ort import OrtSession
|
||||
from app.sessions.rknn import RknnSession, runInference
|
||||
from app.sessions.rknn import RknnSession, run_inference
|
||||
|
||||
from .config import Settings, settings
|
||||
from .models.base import InferenceModel
|
||||
@ -356,11 +356,11 @@ class TestRknnSession:
|
||||
RknnSession(model_path)
|
||||
|
||||
rknn_session.assert_called_once_with(
|
||||
rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=runInference
|
||||
rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=run_inference
|
||||
)
|
||||
|
||||
info.assert_has_calls(
|
||||
[mock.call(f"Loaded RKNN model from {str(model_path).replace('model','rk3566')} with {tpe} threads.")]
|
||||
[mock.call(f"Loaded RKNN model from {str(model_path).replace('model', 'rk3566')} with {tpe} threads.")]
|
||||
)
|
||||
|
||||
def test_run_rknn(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user