forked from Cutlery/immich
		
	* export clip models * export to hf refactored export code * export mclip, general refactoring cleanup * updated conda deps * do transforms with pillow and numpy, add tokenization config to export, general refactoring * moved conda dockerfile, re-added poetry * minor fixes * updated link * updated tests * removed `requirements.txt` from workflow * fixed mimalloc path * removed torchvision * cleaner np typing * review suggestions * update default model name * update test
		
			
				
	
	
		
			68 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			68 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import tempfile
 | |
| import warnings
 | |
| from pathlib import Path
 | |
| 
 | |
| import torch
 | |
| from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
 | |
| from transformers import AutoTokenizer
 | |
| 
 | |
| from .openclip import OpenCLIPModelConfig
 | |
| from .openclip import to_onnx as openclip_to_onnx
 | |
| from .optimize import optimize
 | |
| from .util import get_model_path
 | |
| 
 | |
# Maps each supported M-CLIP model name to the OpenCLIP configuration that
# `to_onnx` hands to the OpenCLIP exporter for the visual side of the model.
_MCLIP_TO_OPENCLIP: dict[str, OpenCLIPModelConfig] = {
    "M-CLIP/XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig(
        "ViT-B-32",
        "openai",
    ),
    "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig(
        "ViT-B-16-plus-240",
        "laion400m_e32",
    ),
    "M-CLIP/LABSE-Vit-L-14": OpenCLIPModelConfig(
        "ViT-L-14",
        "openai",
    ),
    "M-CLIP/XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig(
        "ViT-L-14",
        "openai",
    ),
}
 | |
| 
 | |
| 
 | |
def to_onnx(
    model_name: str,
    output_dir_visual: Path | str,
    output_dir_textual: Path | str,
) -> None:
    """Export an M-CLIP model as ONNX textual and visual models.

    The text encoder of ``model_name`` is exported with
    ``export_text_encoder`` and then passed to ``optimize``. The visual
    encoder is exported by the OpenCLIP exporter, using the configuration
    that ``_MCLIP_TO_OPENCLIP`` associates with ``model_name``. The
    tokenizer for ``model_name`` is saved into ``output_dir_textual``.

    Args:
        model_name: M-CLIP model identifier; must be a key of
            ``_MCLIP_TO_OPENCLIP``.
        output_dir_visual: Output directory handed to the OpenCLIP exporter.
        output_dir_textual: Output directory for the tokenizer files; the
            textual model path is derived from it via ``get_model_path``.

    Raises:
        KeyError: If ``model_name`` has no entry in ``_MCLIP_TO_OPENCLIP``.
    """
    # Resolve the visual counterpart up front so that an unsupported model
    # name fails immediately, before anything is loaded or written to disk.
    openclip_config = _MCLIP_TO_OPENCLIP[model_name]

    textual_path = get_model_path(output_dir_textual)
    # The pretrained checkpoint is cached in a throwaway directory that is
    # removed once the export has finished.
    with tempfile.TemporaryDirectory() as tmpdir:
        model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
        AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)

        # The model is only traced for export, so gradients are not needed.
        for param in model.parameters():
            param.requires_grad_(False)

        export_text_encoder(model, textual_path)
        openclip_to_onnx(openclip_config, output_dir_visual)
        optimize(textual_path)
 | |
| 
 | |
| 
 | |
def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
    """Export the text encoder of an M-CLIP model to an ONNX file.

    ``MultilingualCLIP.forward`` is temporarily replaced by a tensor-only
    variant taking ``input_ids`` and ``attention_mask`` so the model can be
    traced by ``torch.onnx.export``. The original ``forward`` is put back
    afterwards, even if the export raises.

    Args:
        model: The loaded M-CLIP model to export.
        output_path: Destination of the ONNX file.
    """
    output_path = Path(output_path)

    def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        embs = self.transformer(input_ids, attention_mask)[0]
        # Attention-masked mean over dim 1 (presumably the sequence axis of a
        # (batch, sequence, hidden) tensor -- confirm against the transformer).
        embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
        embs = self.LinearTransformation(embs)
        return torch.nn.functional.normalize(embs, dim=-1)

    # unfortunately need to monkeypatch for tracing to work here
    # otherwise it hits the 2GiB protobuf serialization limit
    original_forward = MultilingualCLIP.forward
    MultilingualCLIP.forward = forward
    try:
        # Dummy inputs used only for tracing; both axes are declared dynamic
        # below, so the 1x77 shape is not baked into the exported graph.
        args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            torch.onnx.export(
                model,
                args,
                output_path.as_posix(),
                input_names=["input_ids", "attention_mask"],
                output_names=["text_embedding"],
                opset_version=17,
                dynamic_axes={
                    "input_ids": {0: "batch_size", 1: "sequence_length"},
                    "attention_mask": {0: "batch_size", 1: "sequence_length"},
                },
            )
    finally:
        # Undo the class-level patch so later users of MultilingualCLIP in
        # this process get the library's own forward back.
        MultilingualCLIP.forward = original_forward
 |