mirror of
https://github.com/immich-app/immich.git
synced 2025-07-09 03:04:16 -04:00
more linting
This commit is contained in:
parent
bd9374e4a9
commit
4db85a8954
@ -12,7 +12,7 @@ from app.schemas import SessionNode
|
|||||||
from .rknnpool import RknnPoolExecutor, is_available, soc_name
|
from .rknnpool import RknnPoolExecutor, is_available, soc_name
|
||||||
|
|
||||||
is_available = is_available and settings.rknn
|
is_available = is_available and settings.rknn
|
||||||
model_prefix = Path("rknpu") / soc_name if is_available else None
|
model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None
|
||||||
|
|
||||||
|
|
||||||
def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
||||||
@ -62,10 +62,15 @@ class RknnSession:
|
|||||||
) -> list[NDArray[np.float32]]:
|
) -> list[NDArray[np.float32]]:
|
||||||
input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
|
input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
|
||||||
self.rknnpool.put(input_data)
|
self.rknnpool.put(input_data)
|
||||||
outputs: list[NDArray[np.float32]] = self.rknnpool.get()
|
res = self.rknnpool.get()
|
||||||
return outputs
|
if res is None:
|
||||||
|
raise RuntimeError("RKNN inference failed!")
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class RknnNode(NamedTuple):
|
class RknnNode(NamedTuple):
|
||||||
name: str | None
|
name: str | None
|
||||||
shape: tuple[int, ...]
|
shape: tuple[int, ...]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"]
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
# This code is from leafqycc/rknn-multi-threaded
|
# This code is from leafqycc/rknn-multi-threaded
|
||||||
# Following Apache License 2.0
|
# Following Apache License 2.0
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
@ -21,7 +22,7 @@ def get_soc(device_tree_path: Path | str) -> str | None:
|
|||||||
return soc
|
return soc
|
||||||
log.warning("Device is not supported for RKNN")
|
log.warning("Device is not supported for RKNN")
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e.msg)
|
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -56,19 +57,24 @@ def init_rknn(model_path: str) -> "RKNNLite":
|
|||||||
|
|
||||||
|
|
||||||
class RknnPoolExecutor:
|
class RknnPoolExecutor:
|
||||||
def __init__(self, model_path: str, tpes: int, func):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
tpes: int,
|
||||||
|
func: Callable[["RKNNLite", list[NDArray[np.float32]]], list[NDArray[np.float32]]],
|
||||||
|
) -> None:
|
||||||
self.tpes = tpes
|
self.tpes = tpes
|
||||||
self.queue = Queue()
|
self.queue: Queue[Future[list[NDArray[np.float32]]]] = Queue()
|
||||||
self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
|
self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
|
||||||
self.pool = ThreadPoolExecutor(max_workers=tpes)
|
self.pool = ThreadPoolExecutor(max_workers=tpes)
|
||||||
self.func = func
|
self.func = func
|
||||||
self.num = 0
|
self.num = 0
|
||||||
|
|
||||||
def put(self, inputs) -> None:
|
def put(self, inputs: list[NDArray[np.float32]]) -> None:
|
||||||
self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
|
self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
|
||||||
self.num += 1
|
self.num += 1
|
||||||
|
|
||||||
def get(self) -> list[list[NDArray[np.float32]], bool]:
|
def get(self) -> list[NDArray[np.float32]] | None:
|
||||||
if self.queue.empty():
|
if self.queue.empty():
|
||||||
return None
|
return None
|
||||||
fut = self.queue.get()
|
fut = self.queue.get()
|
||||||
|
@ -25,7 +25,7 @@ from app.models.facial_recognition.detection import FaceDetector
|
|||||||
from app.models.facial_recognition.recognition import FaceRecognizer
|
from app.models.facial_recognition.recognition import FaceRecognizer
|
||||||
from app.sessions.ann import AnnSession
|
from app.sessions.ann import AnnSession
|
||||||
from app.sessions.ort import OrtSession
|
from app.sessions.ort import OrtSession
|
||||||
from app.sessions.rknn import RknnSession, runInference
|
from app.sessions.rknn import RknnSession, run_inference
|
||||||
|
|
||||||
from .config import Settings, settings
|
from .config import Settings, settings
|
||||||
from .models.base import InferenceModel
|
from .models.base import InferenceModel
|
||||||
@ -356,11 +356,11 @@ class TestRknnSession:
|
|||||||
RknnSession(model_path)
|
RknnSession(model_path)
|
||||||
|
|
||||||
rknn_session.assert_called_once_with(
|
rknn_session.assert_called_once_with(
|
||||||
rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=runInference
|
rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), tpes=tpe, func=run_inference
|
||||||
)
|
)
|
||||||
|
|
||||||
info.assert_has_calls(
|
info.assert_has_calls(
|
||||||
[mock.call(f"Loaded RKNN model from {str(model_path).replace('model','rk3566')} with {tpe} threads.")]
|
[mock.call(f"Loaded RKNN model from {str(model_path).replace('model', 'rk3566')} with {tpe} threads.")]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_run_rknn(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None:
|
def test_run_rknn(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user