diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py index 5d1843ab4b..bea3684c66 100644 --- a/machine-learning/app/models/base.py +++ b/machine-learning/app/models/base.py @@ -68,14 +68,16 @@ class InferenceModel(ABC): pass def _download(self) -> None: - ignore_patterns = [] if self.model_format == ModelFormat.ARMNN else ["*.armnn"] - if self.model_format != ModelFormat.RKNN: - ignore_patterns.append("*.rknn") + ignored_patterns: dict[ModelFormat, list[str]] = { + ModelFormat.ARMNN: ["*.rknn"], + ModelFormat.RKNN: ["*.armnn"], + } + snapshot_download( f"immich-app/{clean_name(self.model_name)}", cache_dir=self.cache_dir, local_dir=self.cache_dir, - ignore_patterns=ignore_patterns, + ignore_patterns=ignored_patterns.get(self.model_format, []), ) def _load(self) -> ModelSession: