mirror of
https://github.com/immich-app/immich.git
synced 2025-07-07 10:14:08 -04:00
add test,founds bugs, fix it tomorrow
This commit is contained in:
parent
6c4e6cb96f
commit
4b0f93cf6a
@ -136,6 +136,12 @@ def ann_session() -> Iterator[mock.Mock]:
|
||||
yield mocked
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def rknn_session() -> Iterator[mock.Mock]:
|
||||
with mock.patch("app.sessions.rknn.rknnPoolExecutor") as mocked:
|
||||
yield mocked
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def rmtree() -> Iterator[mock.Mock]:
|
||||
with mock.patch("app.models.base.rmtree", autospec=True) as mocked:
|
||||
|
@ -33,6 +33,7 @@ class RknnSession:
|
||||
|
||||
log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.")
|
||||
self.rknnpool = rknnPoolExecutor(rknnModel=self.model_path.as_posix(), TPEs=self.tpe, func=runInfrence)
|
||||
log.info(f"Loaded RKNN model from {self.model_path} with {self.tpe} threads.")
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.rknnpool.release()
|
||||
|
@ -24,6 +24,7 @@ from app.models.facial_recognition.detection import FaceDetector
|
||||
from app.models.facial_recognition.recognition import FaceRecognizer
|
||||
from app.sessions.ann import AnnSession
|
||||
from app.sessions.ort import OrtSession
|
||||
from app.sessions.rknn import RknnSession, runInfrence
|
||||
|
||||
from .config import Settings, settings
|
||||
from .models.base import InferenceModel
|
||||
@ -68,6 +69,14 @@ class TestBase:
|
||||
|
||||
assert encoder.model_format == ModelFormat.ARMNN
|
||||
|
||||
def test_sets_default_model_format_rknn(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(settings, "rknn", True)
|
||||
mocker.patch("rknn.rknnpool.is_available", False)
|
||||
|
||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||
|
||||
assert encoder.model_format == ModelFormat.ONNX
|
||||
|
||||
def test_casts_cache_dir_string_to_path(self) -> None:
|
||||
cache_dir = "/test_cache"
|
||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=cache_dir)
|
||||
@ -124,7 +133,7 @@ class TestBase:
|
||||
"immich-app/ViT-B-32__openai",
|
||||
cache_dir=encoder.cache_dir,
|
||||
local_dir=encoder.cache_dir,
|
||||
ignore_patterns=["*.armnn"],
|
||||
ignore_patterns=["*.armnn", "*.rknn"],
|
||||
)
|
||||
|
||||
def test_download_downloads_armnn_if_preferred_format(self, snapshot_download: mock.Mock) -> None:
|
||||
@ -135,7 +144,18 @@ class TestBase:
|
||||
"immich-app/ViT-B-32__openai",
|
||||
cache_dir=encoder.cache_dir,
|
||||
local_dir=encoder.cache_dir,
|
||||
ignore_patterns=[],
|
||||
ignore_patterns=["*.rknn"],
|
||||
)
|
||||
|
||||
def test_download_downloads_rknn_if_preferred_format(self, snapshot_download: mock.Mock) -> None:
|
||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai", model_format=ModelFormat.RKNN)
|
||||
encoder.download()
|
||||
|
||||
snapshot_download.assert_called_once_with(
|
||||
"immich-app/ViT-B-32__openai",
|
||||
cache_dir=encoder.cache_dir,
|
||||
local_dir=encoder.cache_dir,
|
||||
ignore_patterns=["*.armnn"],
|
||||
)
|
||||
|
||||
def test_throws_exception_if_model_path_does_not_exist(
|
||||
@ -327,6 +347,61 @@ class TestAnnSession:
|
||||
np_spy.assert_has_calls([mock.call(input1), mock.call(input2)])
|
||||
|
||||
|
||||
class TestRknnSession:
|
||||
def test_creates_rknn_session(self, rknn_session: mock.Mock, info: mock.Mock, mocker: MockerFixture) -> None:
|
||||
model_path = mock.MagicMock(spec=Path)
|
||||
tpe = 1
|
||||
mocker.patch("app.sessions.rknn.soc_name", "rk3566")
|
||||
RknnSession(model_path)
|
||||
|
||||
rknn_session.assert_called_once_with(
|
||||
rknnModel=Path(str(model_path).replace("model", "rk3566")).as_posix(), TPEs=tpe, func=runInfrence
|
||||
)
|
||||
|
||||
info.assert_has_calls(
|
||||
[mock.call(f"Loaded RKNN model from {str(model_path).replace("model","rk3566")} with {tpe} threads.")]
|
||||
)
|
||||
|
||||
# def test_get_inputs(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None:
|
||||
# rknn_session.return_value.load.return_value = 123
|
||||
# rknn_session.return_value.input_shapes = {123: [(1, 3, 224, 224)]}
|
||||
# mocker.patch("app.sessions.rknn.soc_name", "rk3566")
|
||||
# session = RknnSession(Path("ViT-B-32__openai"))
|
||||
|
||||
# inputs = session.get_inputs()
|
||||
|
||||
# assert len(inputs) == 1
|
||||
# assert inputs[0].name is None
|
||||
# assert inputs[0].shape == (1, 3, 224, 224)
|
||||
|
||||
# def test_get_outputs(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None:
|
||||
# rknn_session.return_value.load.return_value = 123
|
||||
# rknn_session.return_value.output_shapes = {123: [(1, 3, 224, 224)]}
|
||||
# mocker.patch("rknn.rknnpool.is_available", True)
|
||||
# mocker.patch("rknn.rknnpool.soc_name", "rk3566")
|
||||
# session = RknnSession(Path("ViT-B-32__openai"))
|
||||
|
||||
# outputs = session.get_outputs()
|
||||
|
||||
# assert len(outputs) == 1
|
||||
# assert outputs[0].name is None
|
||||
# assert outputs[0].shape == (1, 3, 224, 224)
|
||||
|
||||
def test_run(self, rknn_session: mock.Mock, mocker: MockerFixture) -> None:
|
||||
rknn_session.return_value.load.return_value = 123
|
||||
np_spy = mocker.spy(np, "ascontiguousarray")
|
||||
mocker.patch("app.sessions.rknn.soc_name", "rk3566")
|
||||
session = RknnSession(Path("ViT-B-32__openai"))
|
||||
[input1, input2] = [np.random.rand(1, 3, 224, 224).astype(np.float32) for _ in range(2)]
|
||||
input_feed = {"input.1": input1, "input.2": input2}
|
||||
|
||||
session.run(None, input_feed)
|
||||
|
||||
rknn_session.return_value.put.assert_called_once_with([input1, input2])
|
||||
np_spy.call_count == 2
|
||||
np_spy.assert_has_calls([mock.call(input1), mock.call(input2)])
|
||||
|
||||
|
||||
class TestCLIP:
|
||||
embedding = np.random.rand(512).astype(np.float32)
|
||||
cache_dir = Path("test_cache")
|
||||
|
Loading…
x
Reference in New Issue
Block a user