mirror of
				https://github.com/immich-app/immich.git
				synced 2025-11-04 03:27:09 -05:00 
			
		
		
		
	* chore(deps): update machine-learning * fix typing, use new lifespan syntax * wrap in try / finally * move log --------- Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
		
			
				
	
	
		
			69 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			69 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from __future__ import annotations
 | 
						|
 | 
						|
from pathlib import Path
 | 
						|
from typing import Any, NamedTuple
 | 
						|
 | 
						|
import numpy as np
 | 
						|
from numpy.typing import NDArray
 | 
						|
 | 
						|
from ann.ann import Ann
 | 
						|
 | 
						|
from ..config import log, settings
 | 
						|
 | 
						|
 | 
						|
class AnnSession:
    """
    Wrapper for ANN to be drop-in replacement for ONNX session.

    Exposes ``get_inputs``, ``get_outputs`` and ``run`` with ONNX-session-like
    signatures, backed by one ``Ann`` handle that owns a single loaded model.
    """

    def __init__(self, model_path: Path):
        """
        Create an ``Ann`` handle and load the model at ``model_path`` into it.

        Args:
            model_path: Path to the model file. A compiled-network cache is kept next
                to it with the ``.anncache`` suffix; ANN is asked to write it only when
                no such file exists yet.

        Raises:
            Whatever ``Ann`` / ``Ann.load`` raise. If this call created the empty cache
            placeholder and loading fails, the placeholder is removed again so the next
            attempt still asks ANN to save the cache.
        """
        tuning_file = Path(settings.cache_folder) / "gpu-tuning.ann"
        # Make sure the tuning file exists (without clearing its contents). Once filled,
        # the tuning file reduces the cost/time of the first inference after model load
        # by 10s of seconds.
        tuning_file.touch()
        self.ann = Ann(tuning_level=3, tuning_file=tuning_file.as_posix())

        log.info("Loading ANN model %s ...", model_path)
        cache_file = model_path.with_suffix(".anncache")
        # Only request that the compiled network be saved when no cache file exists yet.
        save = not cache_file.is_file()
        if save:
            # Create an empty model cache file before loading.
            cache_file.touch()

        try:
            self.model = self.ann.load(
                model_path.as_posix(),
                save_cached_network=save,
                cached_network_path=cache_file.as_posix(),
            )
        except Exception:
            if save:
                # An empty placeholder left behind would make every later load see an
                # existing cache file and pass save_cached_network=False.
                cache_file.unlink(missing_ok=True)
            raise
        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        """
        Unload the model (if one was loaded) and destroy the ``Ann`` handle.

        Safe on a partially constructed instance: if ``__init__`` raised before
        ``self.ann`` or ``self.model`` was assigned, the missing parts are skipped.
        The handle is still destroyed whenever it exists, even if ``unload`` raises.
        """
        ann = getattr(self, "ann", None)
        if ann is None:
            # Ann() itself never succeeded; nothing to release.
            return
        model = getattr(self, "model", None)
        try:
            # Compare against None explicitly: a model ID of 0 is still a valid ID.
            if model is not None:
                ann.unload(model)
                log.info("Unloaded ANN model %d", model)
        finally:
            ann.destroy()

    def get_inputs(self) -> list[AnnNode]:
        """Return one unnamed ``AnnNode`` per model input, carrying the input shape."""
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[AnnNode]:
        """Return one unnamed ``AnnNode`` per model output, carrying the output shape."""
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """
        Execute the loaded model and return whatever ``Ann.execute`` returns.

        Mirrors the ONNX ``run`` signature; ``output_names`` and ``run_options`` are
        accepted for compatibility but not used. Input names are discarded and the
        arrays are passed positionally in the dict's iteration order.
        NOTE(review): this presumably has to match the model's input order; confirm.
        """
        # Each array is made C-contiguous before being handed to ANN.
        inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)
 | 
						|
 | 
						|
 | 
						|
class AnnNode(NamedTuple):
    """Input/output descriptor returned by ``AnnSession.get_inputs`` and ``get_outputs``."""

    # Tensor name; may be None (AnnSession always passes None here).
    name: str | None
    # Tensor dimensions, one int per axis.
    shape: tuple[int, ...]
 |