fix tests

This commit is contained in:
mertalev 2025-04-19 17:06:53 -04:00
parent e2d80755c6
commit 1c4a8c3968
No known key found for this signature in database
GPG Key ID: DF6ABC77AAD98C95

View File

@ -291,7 +291,6 @@ class TestOrtSession:
assert session.sess_options.execution_mode == ort.ExecutionMode.ORT_SEQUENTIAL
assert session.sess_options.inter_op_num_threads == 1
assert session.sess_options.intra_op_num_threads == 2
assert session.sess_options.enable_cpu_mem_arena is False
def test_sets_default_sess_options_does_not_set_threads_if_non_cpu_and_default_threads(self) -> None:
session = OrtSession("ViT-B-32__openai", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
@ -311,6 +310,8 @@ class TestOrtSession:
def test_uses_arena_if_enabled(self, mocker: MockerFixture) -> None:
mock_settings = mocker.patch("immich_ml.sessions.ort.settings", autospec=True)
mock_settings.model_inter_op_threads = 0
mock_settings.model_intra_op_threads = 0
mock_settings.model_arena = True
session = OrtSession("ViT-B-32__openai", providers=["CPUExecutionProvider"])
@ -319,6 +320,8 @@ class TestOrtSession:
def test_does_not_use_arena_if_disabled(self, mocker: MockerFixture) -> None:
mock_settings = mocker.patch("immich_ml.sessions.ort.settings", autospec=True)
mock_settings.model_inter_op_threads = 0
mock_settings.model_intra_op_threads = 0
mock_settings.model_arena = False
session = OrtSession("ViT-B-32__openai", providers=["CPUExecutionProvider"])