mirror of
				https://github.com/immich-app/immich.git
				synced 2025-11-04 03:27:09 -05:00 
			
		
		
		
	* export clip models * export to HF, refactored export code * export mclip, general refactoring and cleanup * updated conda deps * do transforms with Pillow and NumPy, add tokenization config to export, general refactoring * moved conda Dockerfile, re-added Poetry * minor fixes * updated link * updated tests * removed `requirements.txt` from workflow * fixed mimalloc path * removed torchvision * cleaner np typing * review suggestions * update default model name * update test
		
			
				
	
	
		
			68 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			68 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import tempfile
 | 
						|
import warnings
 | 
						|
from pathlib import Path
 | 
						|
 | 
						|
import torch
 | 
						|
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
 | 
						|
from transformers import AutoTokenizer
 | 
						|
 | 
						|
from .openclip import OpenCLIPModelConfig
 | 
						|
from .openclip import to_onnx as openclip_to_onnx
 | 
						|
from .optimize import optimize
 | 
						|
from .util import get_model_path
 | 
						|
 | 
						|
# Supported M-CLIP text models, each paired with the two positional arguments of
# the OpenCLIPModelConfig used for its visual tower (architecture name and,
# presumably, the pretrained-weights tag). `to_onnx` looks a model name up here
# and hands the resulting config to the OpenCLIP exporter.
_MCLIP_TO_OPENCLIP = {
    mclip_name: OpenCLIPModelConfig(architecture, pretrained)
    for mclip_name, architecture, pretrained in (
        ("M-CLIP/XLM-Roberta-Large-Vit-B-32", "ViT-B-32", "openai"),
        ("M-CLIP/XLM-Roberta-Large-Vit-B-16Plus", "ViT-B-16-plus-240", "laion400m_e32"),
        ("M-CLIP/LABSE-Vit-L-14", "ViT-L-14", "openai"),
        ("M-CLIP/XLM-Roberta-Large-Vit-L-14", "ViT-L-14", "openai"),
    )
}
 | 
						|
 | 
						|
 | 
						|
def to_onnx(
    model_name: str,
    output_dir_visual: Path | str,
    output_dir_textual: Path | str,
) -> None:
    """Export an M-CLIP model as separate textual and visual ONNX models.

    The multilingual text encoder is exported from the M-CLIP checkpoint and then
    optimized. The visual encoder is exported through the OpenCLIP exporter, using
    the config mapped from ``model_name`` in ``_MCLIP_TO_OPENCLIP``.

    Args:
        model_name: Hugging Face ID of the M-CLIP model; must be a key of
            ``_MCLIP_TO_OPENCLIP``.
        output_dir_visual: Output location passed to the OpenCLIP exporter for
            the visual model.
        output_dir_textual: Directory that receives the saved tokenizer files and
            the textual ONNX model (written to ``get_model_path(output_dir_textual)``).

    Raises:
        KeyError: If ``model_name`` has no entry in ``_MCLIP_TO_OPENCLIP``. This is
            raised before anything is downloaded or written.
    """
    # Resolve the visual config up front so an unsupported model name fails
    # immediately instead of after downloading and exporting the text model.
    visual_config = _MCLIP_TO_OPENCLIP[model_name]
    textual_path = get_model_path(output_dir_textual)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Download both the model and the tokenizer into the throwaway directory,
        # so neither is left behind in the default Hugging Face cache.
        model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=tmpdir)
        tokenizer.save_pretrained(output_dir_textual)

        # Export-only use: gradients are never needed.
        for param in model.parameters():
            param.requires_grad_(False)

        export_text_encoder(model, textual_path)
        openclip_to_onnx(visual_config, output_dir_visual)
        optimize(textual_path)
 | 
						|
 | 
						|
 | 
						|
def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
    """Export the text encoder of an M-CLIP ``model`` to an ONNX file.

    The exported graph has inputs ``input_ids`` and ``attention_mask`` and a single
    output ``text_embedding``. The output is the transformer's first output,
    pooled over the sequence weighted by ``attention_mask``, projected by
    ``LinearTransformation`` and L2-normalized along the last dimension. Batch size
    and sequence length are dynamic axes.

    Args:
        model: Loaded ``MultilingualCLIP`` model to export.
        output_path: Destination path of the ONNX file.
    """
    output_path = Path(output_path)

    def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        embs = self.transformer(input_ids, attention_mask)[0]
        # Masked mean over the sequence axis: zero out padded positions, then
        # divide by the number of attended tokens per row.
        embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
        embs = self.LinearTransformation(embs)
        return torch.nn.functional.normalize(embs, dim=-1)

    # unfortunately need to monkeypatch for tracing to work here
    # otherwise it hits the 2GiB protobuf serialization limit.
    # The patch is class-wide, so it is undone in `finally`. Otherwise every
    # MultilingualCLIP in the process would keep this tensor-only signature.
    missing = object()
    original_forward = vars(MultilingualCLIP).get("forward", missing)
    MultilingualCLIP.forward = forward
    try:
        # Dummy (batch=1, seq=77) int32 inputs; used only to trace the graph.
        args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            torch.onnx.export(
                model,
                args,
                output_path.as_posix(),
                input_names=["input_ids", "attention_mask"],
                output_names=["text_embedding"],
                opset_version=17,
                dynamic_axes={
                    "input_ids": {0: "batch_size", 1: "sequence_length"},
                    "attention_mask": {0: "batch_size", 1: "sequence_length"},
                    # Without this the output is declared with the traced batch size (1).
                    "text_embedding": {0: "batch_size"},
                },
            )
    finally:
        if original_forward is missing:
            # `forward` was inherited rather than defined on the class; drop the override.
            del MultilingualCLIP.forward
        else:
            MultilingualCLIP.forward = original_forward
 |