mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-30 18:35:00 -04:00 
			
		
		
		
feat(ml): ARMNN acceleration for CLIP

* wrap ANN as ONNX-Session
* strict typing
* normalize ARMNN CLIP embedding
* mutex to handle concurrent execution
* make inputs contiguous
* fine-grained locking; concurrent network execution

---------

Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
		
			
				
	
	
		
			158 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			158 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging
 | |
| import os
 | |
| import platform
 | |
| import subprocess
 | |
| from abc import abstractmethod
 | |
| 
 | |
| import onnx
 | |
| import open_clip
 | |
| import torch
 | |
| from onnx2torch import convert
 | |
| from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
 | |
| from tinynn.converter import TFLiteConverter
 | |
| 
 | |
| 
 | |
| class ExportBase(torch.nn.Module):
 | |
|     input_shape: tuple[int, ...]
 | |
| 
 | |
|     def __init__(self, device: torch.device, name: str):
 | |
|         super().__init__()
 | |
|         self.device = device
 | |
|         self.name = name
 | |
|         self.optimize = 5
 | |
|         self.nchw_transpose = False
 | |
| 
 | |
|     @abstractmethod
 | |
|     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
 | |
|         pass
 | |
| 
 | |
|     def dummy_input(self) -> torch.FloatTensor:
 | |
|         return torch.rand((1, 3, 224, 224), device=self.device)
 | |
| 
 | |
| 
 | |
class ArcFace(ExportBase):
    """Face-recognition embedding model converted from an ONNX file.

    The ONNX graph's first input is pinned to ``input_shape`` before being
    converted to PyTorch with ``onnx2torch``. On CUDA the model runs in fp16.
    """

    input_shape = (1, 3, 112, 112)

    def __init__(self, onnx_model_path: str, device: torch.device):
        """
        Args:
            onnx_model_path: Path to the ONNX model; its base name without the
                extension becomes the export name.
            device: Device to place the converted model on.
        """
        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
        super().__init__(device, name)
        onnx_model = onnx.load_model(onnx_model_path)
        # Make the first input static, then update output shapes to match.
        make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
        fix_output_shapes(onnx_model)
        self.model = convert(onnx_model).to(device)
        if self.device.type == "cuda":
            self.model = self.model.half()

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        """Return the embedding for ``input_tensor`` as a float32 tensor."""
        # The weights are fp16 on CUDA, so the input must match.
        model_input = input_tensor.half() if self.device.type == "cuda" else input_tensor
        # ``.float()`` already guarantees float32. Deliberately no
        # ``isinstance(..., torch.FloatTensor)`` check: that legacy type only
        # matches CPU tensors and would reject the CUDA output.
        embedding: torch.FloatTensor = self.model(model_input).float()
        return embedding

    def dummy_input(self) -> torch.FloatTensor:
        """Random tensor of ``input_shape`` on ``device``, used for tracing."""
        return torch.rand(self.input_shape, device=self.device)
 | |
| 
 | |
| 
 | |
class RetinaFace(ExportBase):
    """Face-detection network loaded from ONNX and converted to PyTorch."""

    input_shape = (1, 3, 640, 640)

    def __init__(self, onnx_model_path: str, device: torch.device):
        """
        Args:
            onnx_model_path: ONNX file to convert; its base name without the
                extension names the exported artifacts.
            device: Device hosting the converted network.
        """
        stem = os.path.splitext(os.path.basename(onnx_model_path))[0]
        super().__init__(device, stem)
        # TFLite optimize level for this model (the base class default is 5).
        self.optimize = 3
        network = convert(onnx_model_path).eval().to(device)
        # fp16 weights on CUDA; forward() casts the input to match.
        self.model = network.half() if self.device.type == "cuda" else network

    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.FloatTensor]:
        """Run detection and return every network output upcast to float32.

        NOTE(review): the annotation describes a 1-tuple, but one element is
        returned per model output — confirm and widen together with the base class.
        """
        if self.device.type == "cuda":
            input_tensor = input_tensor.half()
        raw_outputs = self.model(input_tensor)
        return tuple([t.float() for t in raw_outputs])

    def dummy_input(self) -> torch.FloatTensor:
        """Random tracing input of ``input_shape`` placed on ``device``."""
        shape = self.input_shape
        return torch.rand(shape, device=self.device)
 | |
| 
 | |
| 
 | |
class ClipVision(ExportBase):
    """Image tower of an ``open_clip`` model, exported on its own."""

    input_shape = (1, 3, 224, 224)

    def __init__(self, model_name: str, weights: str, device: torch.device):
        """
        Args:
            model_name: ``open_clip`` architecture name.
            weights: Pretrained weight tag; also appended to the export name.
            device: Device to create the model on (fp16 on CUDA, fp32 otherwise).
        """
        on_gpu = device.type == "cuda"
        super().__init__(device, "__".join((model_name, weights)))
        self.model = open_clip.create_model(
            model_name,
            weights,
            precision="fp16" if on_gpu else "fp32",
            jit=False,
            require_pretrained=True,
            device=device,
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        """Return the normalized image embedding as float32."""
        pixels = input_tensor.half() if self.device.type == "cuda" else input_tensor
        features = self.model.encode_image(pixels, normalize=True)
        return features.float()
 | |
| 
 | |
| 
 | |
def export(model: ExportBase, output_dir: str = "output") -> None:
    """Trace ``model`` and write ``<name>.tflite`` and ``<name>.armnn`` files.

    Args:
        model: Model wrapper to export; its ``name``, ``optimize`` and
            ``nchw_transpose`` attributes control the output.
        output_dir: Directory for the generated files (created if missing).

    Raises:
        subprocess.CalledProcessError: if ``./armnnconverter`` exits non-zero.
    """
    model.eval()
    model.requires_grad_(False)
    dummy_input = model.dummy_input()
    # Run one forward pass before tracing.
    model(dummy_input)
    traced = torch.jit.trace(model, dummy_input)  # type: ignore[no-untyped-call,attr-defined]

    os.makedirs(output_dir, exist_ok=True)
    tflite_model_path = f"{output_dir}/{model.name}.tflite"
    converter = TFLiteConverter(
        traced,
        dummy_input,
        tflite_model_path,
        optimize=model.optimize,
        nchw_transpose=model.nchw_transpose,
    )
    # segfaults on ARM, must run on x86_64 / AMD64
    converter.convert()

    armnn_model_path = f"{output_dir}/{model.name}.armnn"
    # Scope the library search path to the converter process instead of
    # mutating this process's environment.
    converter_env = {**os.environ, "LD_LIBRARY_PATH": "armnn"}
    subprocess.run(
        [
            "./armnnconverter",
            "-f",
            "tflite-binary",
            "-m",
            tflite_model_path,
            "-i",
            "input_tensor",
            "-o",
            "output_tensor",
            "-p",
            armnn_model_path,
        ],
        env=converter_env,
        # Fail loudly instead of silently leaving a missing/partial .armnn file.
        check=True,
    )
 | |
| 
 | |
| 
 | |
def main() -> None:
    """Export every supported model; refuses to run on non-x86_64/AMD64 hosts."""
    arch = platform.machine()
    if arch not in ("x86_64", "AMD64"):
        raise RuntimeError(f"Can only run on x86_64 / AMD64, not {arch}")

    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        logging.warning(
            "No CUDA available, cannot create fp16 model! proceeding to create a fp32 model (use only for testing)"
        )
    device = torch.device("cuda" if use_cuda else "cpu")

    # All models are constructed up front, then exported one by one.
    for model in (
        ClipVision("ViT-B-32", "openai", device),
        ArcFace("buffalo_l_rec.onnx", device),
        RetinaFace("buffalo_l_det.onnx", device),
    ):
        export(model)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Run the whole export with autograd disabled; no gradients are needed.
    grad_guard = torch.no_grad()
    with grad_guard:
        main()
 |