mirror of
https://github.com/immich-app/immich.git
synced 2025-10-24 07:19:05 -04:00
* refactor: migrate person repository to kysely * `asVector` begone * linting * fix metadata faces * update test --------- Co-authored-by: Alex <alex.tran1502@gmail.com> Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
import json
|
|
from abc import abstractmethod
|
|
from functools import cached_property
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
from PIL import Image
|
|
|
|
from app.config import log
|
|
from app.models.base import InferenceModel
|
|
from app.models.transforms import (
|
|
crop_pil,
|
|
decode_pil,
|
|
get_pil_resampling,
|
|
normalize,
|
|
resize_pil,
|
|
serialize_np_array,
|
|
to_numpy,
|
|
)
|
|
from app.schemas import ModelSession, ModelTask, ModelType
|
|
|
|
|
|
class BaseCLIPVisualEncoder(InferenceModel):
    """Base class for CLIP visual encoders registered for the search task.

    Subclasses implement :meth:`transform`, which converts a PIL image into
    the input feed handed to ``self.session.run``.
    """

    depends = []
    identity = (ModelType.VISUAL, ModelTask.SEARCH)

    def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> str:
        """Encode an image and return the serialized session output.

        Args:
            inputs: A PIL image, or raw bytes; passed through ``decode_pil``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The result of ``serialize_np_array`` applied to element ``[0][0]``
            of the session outputs.
        """
        image = decode_pil(inputs)
        # First output of the session, then its first entry
        # (presumably the single item of a batch of one -- TODO confirm).
        res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
        return serialize_np_array(res)

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Convert ``image`` into the input feed passed to ``self.session.run``."""

    @property
    def model_cfg_path(self) -> Path:
        """Path to ``config.json`` under ``self.cache_dir``."""
        return self.cache_dir / "config.json"

    @property
    def preprocess_cfg_path(self) -> Path:
        """Path to ``preprocess_cfg.json`` under ``self.model_dir``."""
        return self.model_dir / "preprocess_cfg.json"

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        """Parsed contents of :attr:`model_cfg_path`, loaded once and cached."""
        log.debug(f"Loading model config for CLIP model '{self.model_name}'")
        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with self.model_cfg_path.open() as cfg_file:
            model_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
        return model_cfg

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        """Parsed contents of :attr:`preprocess_cfg_path`, loaded once and cached."""
        log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
        # Same as model_cfg: close the handle as soon as parsing finishes.
        with self.preprocess_cfg_path.open() as cfg_file:
            preprocess_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
        return preprocess_cfg
|
|
|
|
|
|
class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
    """CLIP visual encoder whose preprocessing is driven by ``preprocess_cfg``."""

    def _load(self) -> ModelSession:
        """Read preprocessing parameters from the config, then load the session.

        Sets ``self.size``, ``self.resampling``, ``self.mean`` and ``self.std``
        before delegating to the parent ``_load``.
        """
        cfg = self.preprocess_cfg

        # "size" may be a list or a bare int; only the first entry of a list is used.
        configured_size: list[int] | int = cfg["size"]
        if isinstance(configured_size, list):
            self.size = configured_size[0]
        else:
            self.size = configured_size

        # NOTE(review): self.resampling is stored here but transform() below does
        # not pass it to resize_pil -- confirm whether that is intended.
        self.resampling = get_pil_resampling(cfg["interpolation"])
        self.mean = np.array(cfg["mean"], dtype=np.float32)
        self.std = np.array(cfg["std"], dtype=np.float32)

        return super()._load()

    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Resize, crop and normalize ``image`` into the session input feed.

        Returns:
            A dict with a single ``"image"`` key holding the normalized array,
            with axes reordered by ``transpose(2, 0, 1)`` and a new leading
            axis inserted.
        """
        resized = resize_pil(image, self.size)
        cropped = crop_pil(resized, self.size)
        normalized = normalize(to_numpy(cropped), self.mean, self.std)
        # Move the last axis to the front, then add a leading axis of length 1.
        reordered = normalized.transpose(2, 0, 1)
        return {"image": np.expand_dims(reordered, 0)}
|