From d7381ab5c1d7d700abb82d5b5f19b55937a56f7d Mon Sep 17 00:00:00 2001 From: yoni13 Date: Sat, 18 Jan 2025 04:46:09 +0000 Subject: [PATCH] refactor ignore_patterns --- machine-learning/app/models/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py index 5d1843ab4b..bea3684c66 100644 --- a/machine-learning/app/models/base.py +++ b/machine-learning/app/models/base.py @@ -68,14 +68,16 @@ class InferenceModel(ABC): pass def _download(self) -> None: - ignore_patterns = [] if self.model_format == ModelFormat.ARMNN else ["*.armnn"] - if self.model_format != ModelFormat.RKNN: - ignore_patterns.append("*.rknn") + ignored_patterns: dict[ModelFormat, list[str]] = { + ModelFormat.ARMNN: ["*.rknn"], + ModelFormat.RKNN: ["*.armnn"], + } + snapshot_download( f"immich-app/{clean_name(self.model_name)}", cache_dir=self.cache_dir, local_dir=self.cache_dir, - ignore_patterns=ignore_patterns, + ignore_patterns=ignored_patterns.get(self.model_format, []), ) def _load(self) -> ModelSession: