from __future__ import annotations from pathlib import Path from typing import Any, NamedTuple import numpy as np from numpy.typing import NDArray from immich_ml.config import log, settings from immich_ml.schemas import SessionNode from .loader import Ann class AnnSession: """ Wrapper for ANN to be drop-in replacement for ONNX session. """ def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None: self.model_path = model_path self.cache_dir = cache_dir self.ann = Ann(tuning_level=settings.ann_tuning_level, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix()) log.info("Loading ANN model %s ...", model_path) self.model = self.ann.load( model_path.as_posix(), cached_network_path=model_path.with_suffix(".anncache").as_posix(), fp16=settings.ann_fp16_turbo, ) log.info("Loaded ANN model with ID %d", self.model) def __del__(self) -> None: self.ann.unload(self.model) log.info("Unloaded ANN model %d", self.model) self.ann.destroy() def get_inputs(self) -> list[SessionNode]: shapes = self.ann.input_shapes[self.model] return [AnnNode(None, s) for s in shapes] def get_outputs(self) -> list[SessionNode]: shapes = self.ann.output_shapes[self.model] return [AnnNode(None, s) for s in shapes] def run( self, output_names: list[str] | None, input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]], run_options: Any = None, ) -> list[NDArray[np.float32]]: inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()] return self.ann.execute(self.model, inputs) class AnnNode(NamedTuple): name: str | None shape: tuple[int, ...]