import json
from abc import abstractmethod
from functools import cached_property
from io import BytesIO
from pathlib import Path
from typing import Any

import numpy as np
from numpy.typing import NDArray
from PIL import Image

from app.config import log
from app.models.transforms import crop_pil, get_pil_resampling, normalize, resize_pil, to_numpy
from app.schemas import ModelSession, ModelTask, ModelType
from app.models.base import InferenceModel


class BaseCLIPVisualEncoder(InferenceModel):
    """Shared logic for CLIP visual encoders.

    Resolves the on-disk locations of the model and its JSON configs, loads
    those configs lazily, and runs inference on a single image. Subclasses
    supply the image preprocessing via :meth:`transform`.
    """

    _model_task = ModelTask.SEARCH
    _model_type = ModelType.VISUAL

    def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> NDArray[np.float32]:
        """Run the session on one image and return its result array.

        Args:
            inputs: A PIL image, or raw encoded image bytes that are opened
                with ``Image.open`` before preprocessing.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The first element of the first output produced by the session.
        """
        if isinstance(inputs, bytes):
            inputs = Image.open(BytesIO(inputs))
        # `[0]` picks the first session output, the second `[0]` its first entry
        # (presumably the single item of a batch of one -- confirm against the model).
        res: NDArray[np.float32] = self.session.run(None, self.transform(inputs))[0][0]
        return res

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Convert a PIL image into the feed dict passed to ``self.session.run``."""
        pass

    @property
    def model_dir(self) -> Path:
        """Directory holding the visual model files (``<cache_dir>/visual``)."""
        return self.cache_dir / "visual"

    @property
    def model_cfg_path(self) -> Path:
        """Path of the model config; note it sits in ``cache_dir``, not ``model_dir``."""
        return self.cache_dir / "config.json"

    @property
    def model_path(self) -> Path:
        """Path of the model file, using ``preferred_runtime`` as the file extension."""
        return self.model_dir / f"model.{self.preferred_runtime}"

    @property
    def preprocess_cfg_path(self) -> Path:
        """Path of the visual preprocessing config inside ``model_dir``."""
        return self.model_dir / "preprocess_cfg.json"

    @property
    def cached(self) -> bool:
        """Whether the model file already exists on disk."""
        return self.model_path.is_file()

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        """Parsed contents of ``model_cfg_path``; read once and cached on the instance."""
        log.debug(f"Loading model config for CLIP model '{self.model_name}'")
        # Use a context manager so the file handle is closed deterministically.
        with self.model_cfg_path.open() as cfg_file:
            model_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
        return model_cfg

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        """Parsed contents of ``preprocess_cfg_path``; read once and cached on the instance."""
        log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
        # Use a context manager so the file handle is closed deterministically.
        with self.preprocess_cfg_path.open() as cfg_file:
            preprocess_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
        return preprocess_cfg


class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
    """CLIP visual encoder whose preprocessing is driven by ``preprocess_cfg``."""

    def _load(self) -> ModelSession:
        """Read preprocessing parameters from the config, then load the session.

        Sets ``size``, ``resampling``, ``mean`` and ``std`` on the instance
        before delegating to the parent ``_load``.
        """
        size: list[int] | int = self.preprocess_cfg["size"]
        # The config may give the size as a list or a bare int; only the first entry is used.
        self.size = size[0] if isinstance(size, list) else size
        # NOTE(review): `resampling` is stored here but not read by `transform` below --
        # confirm whether `resize_pil` is expected to receive it.
        self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
        self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
        self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)

        return super()._load()

    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Resize, crop, convert and normalize an image into the session feed dict.

        Returns:
            A dict with the single key ``"image"`` holding the preprocessed array
            with a leading axis of length one added.
        """
        image = resize_pil(image, self.size)
        image = crop_pil(image, self.size)
        image_np = to_numpy(image)
        image_np = normalize(image_np, self.mean, self.std)
        # Move the last axis to the front (presumably HWC -> CHW -- confirm `to_numpy`
        # layout), then prepend an axis of length one.
        return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}