forked from Cutlery/immich
		
	chore(ml): removed vit-b check and st warning (#4422)
This commit is contained in:
		
							parent
							
								
									b8d6cc1e09
								
							
						
					
					
						commit
						d8ecefaea5
					
				| @ -16,13 +16,6 @@ from ..config import log | |||||||
| from ..schemas import ModelType | from ..schemas import ModelType | ||||||
| from .base import InferenceModel | from .base import InferenceModel | ||||||
| 
 | 
 | ||||||
| _ST_TO_JINA_MODEL_NAME = { |  | ||||||
|     "clip-ViT-B-16": "ViT-B-16::openai", |  | ||||||
|     "clip-ViT-B-32": "ViT-B-32::openai", |  | ||||||
|     "clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32", |  | ||||||
|     "clip-ViT-L-14": "ViT-L-14::openai", |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| class CLIPEncoder(InferenceModel): | class CLIPEncoder(InferenceModel): | ||||||
|     _model_type = ModelType.CLIP |     _model_type = ModelType.CLIP | ||||||
| @ -36,11 +29,10 @@ class CLIPEncoder(InferenceModel): | |||||||
|     ) -> None: |     ) -> None: | ||||||
|         if mode is not None and mode not in ("text", "vision"): |         if mode is not None and mode not in ("text", "vision"): | ||||||
|             raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'") |             raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'") | ||||||
|         if "vit-b" not in model_name.lower(): |         if model_name not in _MODELS: | ||||||
|             raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'") |             raise ValueError(f"Unknown model name {model_name}.") | ||||||
|         self.mode = mode |         self.mode = mode | ||||||
|         jina_model_name = self._get_jina_model_name(model_name) |         super().__init__(model_name, cache_dir, **model_kwargs) | ||||||
|         super().__init__(jina_model_name, cache_dir, **model_kwargs) |  | ||||||
| 
 | 
 | ||||||
|     def _download(self) -> None: |     def _download(self) -> None: | ||||||
|         models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name] |         models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name] | ||||||
| @ -104,20 +96,6 @@ class CLIPEncoder(InferenceModel): | |||||||
| 
 | 
 | ||||||
|         return outputs[0][0].tolist() |         return outputs[0][0].tolist() | ||||||
| 
 | 
 | ||||||
|     def _get_jina_model_name(self, model_name: str) -> str: |  | ||||||
|         if model_name in _MODELS: |  | ||||||
|             return model_name |  | ||||||
|         elif model_name in _ST_TO_JINA_MODEL_NAME: |  | ||||||
|             log.warn( |  | ||||||
|                 ( |  | ||||||
|                     f"Sentence-Transformer models like '{model_name}' are not supported." |  | ||||||
|                     f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'." |  | ||||||
|                 ), |  | ||||||
|             ) |  | ||||||
|             return _ST_TO_JINA_MODEL_NAME[model_name] |  | ||||||
|         else: |  | ||||||
|             raise ValueError(f"Unknown model name {model_name}.") |  | ||||||
| 
 |  | ||||||
|     def _download_model(self, model_name: str, model_md5: str) -> bool: |     def _download_model(self, model_name: str, model_md5: str) -> bool: | ||||||
|         # downloading logic is adapted from clip-server's CLIPOnnxModel class |         # downloading logic is adapted from clip-server's CLIPOnnxModel class | ||||||
|         download_model( |         download_model( | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user