refactor ignore_patterns

This commit is contained in:
yoni13 2025-01-18 04:46:09 +00:00
parent 87a46dcc5e
commit d7381ab5c1

View File

@ -68,14 +68,16 @@ class InferenceModel(ABC):
pass pass
def _download(self) -> None: def _download(self) -> None:
ignore_patterns = [] if self.model_format == ModelFormat.ARMNN else ["*.armnn"] ignored_patterns: dict[ModelFormat, list[str]] = {
if self.model_format != ModelFormat.RKNN: ModelFormat.ARMNN: ["*.rknn"],
ignore_patterns.append("*.rknn") ModelFormat.RKNN: ["*.armnn"],
}
snapshot_download( snapshot_download(
f"immich-app/{clean_name(self.model_name)}", f"immich-app/{clean_name(self.model_name)}",
cache_dir=self.cache_dir, cache_dir=self.cache_dir,
local_dir=self.cache_dir, local_dir=self.cache_dir,
ignore_patterns=ignore_patterns, ignore_patterns=ignored_patterns.get(self.model_format, []),
) )
def _load(self) -> ModelSession: def _load(self) -> ModelSession: