mirror of https://github.com/immich-app/immich.git (synced 2025-11-04 03:39:37 -05:00)
feat(ml): ARMNN acceleration for CLIP

* wrap ANN as ONNX-Session
* strict typing
* normalize ARMNN CLIP embedding
* mutex to handle concurrent execution
* make inputs contiguous
* fine-grained locking; concurrent network execution

Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
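"""Export Immich ML models to TFLite and ARMNN.

Covers the CLIP visual encoder, the ArcFace recognition model, and the
RetinaFace detection model. Pipeline: PyTorch module -> TorchScript trace ->
TFLite (tinynn) -> ARMNN (armnnconverter CLI). Must run on x86_64 / AMD64;
the TFLite conversion segfaults on ARM.
"""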
import logging
import os
import platform
import subprocess
from abc import abstractmethod

import onnx
import open_clip
import torch
from onnx2torch import convert
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
from tinynn.converter import TFLiteConverter
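
# Common base for exportable models: subclasses pin `input_shape` and implement
# `forward`. `optimize` and `nchw_transpose` are forwarded to TFLiteConverter.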
class ExportBase(torch.nn.Module):
    input_shape: tuple[int, ...]

    def __init__(self, device: torch.device, name: str):
        super().__init__()
        self.device = device
        self.name = name
        self.optimize = 5
        self.nchw_transpose = False

    @abstractmethod
    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
        pass

    def dummy_input(self) -> torch.FloatTensor:
        return torch.rand((1, 3, 224, 224), device=self.device)
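
# ArcFace facial-recognition model loaded from ONNX. The graph's input shape is
# pinned to (1, 3, 112, 112) so the converted module can be traced with a static
# shape; weights are cast to fp16 on CUDA.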
class ArcFace(ExportBase):
    input_shape = (1, 3, 112, 112)

    def __init__(self, onnx_model_path: str, device: torch.device):
        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
        super().__init__(device, name)
        onnx_model = onnx.load_model(onnx_model_path)
        make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
        fix_output_shapes(onnx_model)
        self.model = convert(onnx_model).to(device)
        if self.device.type == "cuda":
            self.model = self.model.half()

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        embedding: torch.FloatTensor = self.model(
            input_tensor.half() if self.device.type == "cuda" else input_tensor
        ).float()
        assert isinstance(embedding, torch.FloatTensor)
        return embedding

    def dummy_input(self) -> torch.FloatTensor:
        return torch.rand(self.input_shape, device=self.device)
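
# RetinaFace face-detection model. Converted directly from the ONNX file and run
# at TFLite optimization level 3 rather than the default 5 (presumably the higher
# level breaks the detection graph, though the source does not say).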
class RetinaFace(ExportBase):
    input_shape = (1, 3, 640, 640)

    def __init__(self, onnx_model_path: str, device: torch.device):
        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
        super().__init__(device, name)
        self.optimize = 3
        self.model = convert(onnx_model_path).eval().to(device)
        if self.device.type == "cuda":
            self.model = self.model.half()

    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.FloatTensor]:
        out: torch.Tensor = self.model(input_tensor.half() if self.device.type == "cuda" else input_tensor)
        return tuple(o.float() for o in out)

    def dummy_input(self) -> torch.FloatTensor:
        return torch.rand(self.input_shape, device=self.device)
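
# CLIP visual encoder from open_clip. `normalize=True` bakes L2 normalization of
# the embedding into the exported graph, matching the "normalize ARMNN CLIP
# embedding" item in the commit message.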
class ClipVision(ExportBase):
    input_shape = (1, 3, 224, 224)

    def __init__(self, model_name: str, weights: str, device: torch.device):
        super().__init__(device, model_name + "__" + weights)
        self.model = open_clip.create_model(
            model_name,
            weights,
            precision="fp16" if device.type == "cuda" else "fp32",
            jit=False,
            require_pretrained=True,
            device=device,
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        embedding: torch.Tensor = self.model.encode_image(
            input_tensor.half() if self.device.type == "cuda" else input_tensor,
            normalize=True,
        ).float()
        return embedding
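
# Traces the model to TorchScript, converts it to TFLite via tinynn, then shells
# out to ArmNN's `armnnconverter` binary (expected next to this script, with its
# shared libraries under ./armnn). The subprocess call is equivalent to:
#
#   LD_LIBRARY_PATH=armnn ./armnnconverter -f tflite-binary \
#       -m output/<name>.tflite -i input_tensor -o output_tensor \
#       -p output/<name>.armnn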
def export(model: ExportBase) -> None:
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    dummy_input = model.dummy_input()
    model(dummy_input)
    jit = torch.jit.trace(model, dummy_input)  # type: ignore[no-untyped-call,attr-defined]
    tflite_model_path = f"output/{model.name}.tflite"
    os.makedirs("output", exist_ok=True)

    converter = TFLiteConverter(
        jit,
        dummy_input,
        tflite_model_path,
        optimize=model.optimize,
        nchw_transpose=model.nchw_transpose,
    )
    # segfaults on ARM, must run on x86_64 / AMD64
    converter.convert()

    armnn_model_path = f"output/{model.name}.armnn"
    os.environ["LD_LIBRARY_PATH"] = "armnn"
    subprocess.run(
        [
            "./armnnconverter",
            "-f",
            "tflite-binary",
            "-m",
            tflite_model_path,
            "-i",
            "input_tensor",
            "-o",
            "output_tensor",
            "-p",
            armnn_model_path,
        ]
    )
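
# Exports all three models. fp16 weights are used when CUDA is available;
# otherwise fp32 with a warning (fp32 output is intended for testing only).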
def main() -> None:
    if platform.machine() not in ("x86_64", "AMD64"):
        raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type != "cuda":
        logging.warning(
            "No CUDA available, cannot create an fp16 model! Proceeding to create an fp32 model (use only for testing)"
        )
    models = [
        ClipVision("ViT-B-32", "openai", device),
        ArcFace("buffalo_l_rec.onnx", device),
        RetinaFace("buffalo_l_det.onnx", device),
    ]
    for model in models:
        export(model)
if __name__ == "__main__":
    with torch.no_grad():
        main()
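
# Usage sketch (assumptions: the buffalo_l ONNX weights, the armnnconverter
# binary, and its libraries under ./armnn all sit in the working directory;
# the file name "export.py" is illustrative):
#
#   python export.py
#
# Converted models are written to ./output as <name>.tflite and <name>.armnn.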