This commit is contained in:
mertalev 2025-04-19 16:30:50 -04:00
parent 543bc72ae3
commit 5a3b11d603
No known key found for this signature in database
GPG Key ID: DF6ABC77AAD98C95

View File

@ -180,6 +180,7 @@ class TestOrtSession:
CUDA_EP_OUT_OF_ORDER = ["CPUExecutionProvider", "CUDAExecutionProvider"]
TRT_EP = ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]
ROCM_EP = ["ROCMExecutionProvider", "CPUExecutionProvider"]
COREML_EP = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
@pytest.mark.providers(CPU_EP)
def test_sets_cpu_provider(self, providers: list[str]) -> None:
@ -225,6 +226,12 @@ class TestOrtSession:
assert session.providers == self.ROCM_EP
@pytest.mark.providers(COREML_EP)
def test_uses_coreml(self, providers: list[str]) -> None:
session = OrtSession("ViT-B-32__openai")
assert session.providers == self.COREML_EP
def test_sets_provider_kwarg(self) -> None:
providers = ["CUDAExecutionProvider"]
session = OrtSession("ViT-B-32__openai", providers=providers)