mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-30 18:22:37 -04:00 
			
		
		
		
	
		
			
				
	
	
		
			59 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			59 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| from pathlib import Path
 | |
| from typing import Any, NamedTuple
 | |
| 
 | |
| import numpy as np
 | |
| from numpy.typing import NDArray
 | |
| 
 | |
| from ann.ann import Ann
 | |
| from app.schemas import SessionNode
 | |
| 
 | |
| from ..config import log, settings
 | |
| 
 | |
| 
 | |
class AnnSession:
    """
    Wrapper for ANN to be drop-in replacement for ONNX session.

    Creates an ``Ann`` context, loads a single model into it, and exposes
    ``get_inputs`` / ``get_outputs`` / ``run`` methods. The context and the
    loaded model are released when the session object is finalized.
    """

    def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
        """
        Create the ANN context and load the model at ``model_path``.

        Args:
            model_path: Path of the model file handed to ``Ann.load``. The
                cached network path is the same path with its suffix replaced
                by ``.anncache``.
            cache_dir: Directory in which the ``gpu-tuning.ann`` tuning file
                path is built. Defaults to ``settings.cache_folder`` (the
                default is evaluated once, when the class is defined).
        """
        self.model_path = model_path
        self.cache_dir = cache_dir
        self.ann = Ann(tuning_level=settings.ann_tuning_level, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())

        log.info("Loading ANN model %s ...", model_path)
        self.model = self.ann.load(
            model_path.as_posix(),
            cached_network_path=model_path.with_suffix(".anncache").as_posix(),
            fp16=settings.ann_fp16_turbo,
        )
        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        """
        Unload the model (if one was loaded) and destroy the ANN context.

        ``__del__`` also runs for instances whose ``__init__`` raised part-way
        through, so neither ``self.ann`` nor ``self.model`` is guaranteed to
        exist here.
        """
        ann = getattr(self, "ann", None)
        if ann is None:
            # The Ann constructor itself failed; there is nothing to release.
            return

        # Compare against None rather than truthiness: the ID is logged with
        # %d, so it is an integer and 0 may be a valid model ID.
        model = getattr(self, "model", None)
        try:
            if model is not None:
                ann.unload(model)
                log.info("Unloaded ANN model %d", model)
        finally:
            # Destroy the context even when no model was loaded or unloading
            # raised, so the context is not left behind.
            ann.destroy()

    def get_inputs(self) -> list[SessionNode]:
        """Return one unnamed ``AnnNode`` per input shape recorded for the loaded model."""
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[SessionNode]:
        """Return one unnamed ``AnnNode`` per output shape recorded for the loaded model."""
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """
        Execute the loaded model on the given inputs.

        Args:
            output_names: Not used by this implementation.
            input_feed: Mapping of input name to array. Only the values are
                used, in the dict's iteration order; the names are ignored.
            run_options: Not used by this implementation.

        Returns:
            Whatever ``Ann.execute`` returns for the loaded model.
        """
        # Each input is passed through np.ascontiguousarray before execution.
        # NOTE(review): presumably the native library requires C-contiguous
        # buffers - confirm against the Ann bindings.
        inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)
 | |
| 
 | |
| 
 | |
| class AnnNode(NamedTuple):
 | |
|     name: str | None
 | |
|     shape: tuple[int, ...]
 |