mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-26 16:34:43 -04:00 
			
		
		
		
	feat(ml)!: switch image classification and CLIP models to ONNX (#3809)
This commit is contained in:
		
							parent
							
								
									8211afb726
								
							
						
					
					
						commit
						165b91b068
					
				
							
								
								
									
										1
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							| @ -171,6 +171,7 @@ jobs: | ||||
|       - name: Install dependencies | ||||
|         run: | | ||||
|           poetry install --with dev | ||||
|           poetry run pip install --no-deps -r requirements.txt | ||||
|       - name: Lint with ruff | ||||
|         run: | | ||||
|           poetry run ruff check --format=github app | ||||
|  | ||||
| @ -10,8 +10,9 @@ RUN poetry config installer.max-workers 10 && \ | ||||
| RUN python -m venv /opt/venv | ||||
| ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}" | ||||
| 
 | ||||
| COPY poetry.lock pyproject.toml ./ | ||||
| COPY poetry.lock pyproject.toml requirements.txt ./ | ||||
| RUN poetry install --sync --no-interaction --no-ansi --no-root --only main | ||||
| RUN pip install --no-deps -r requirements.txt | ||||
| 
 | ||||
| FROM python:3.11.4-slim-bullseye@sha256:91d194f58f50594cda71dcd2e8fdefd90e7ecc57d07823813b67c8521e565dcd | ||||
| 
 | ||||
|  | ||||
| @ -1,3 +1,4 @@ | ||||
| import os | ||||
| from pathlib import Path | ||||
| 
 | ||||
| from pydantic import BaseSettings | ||||
| @ -8,8 +9,8 @@ from .schemas import ModelType | ||||
| class Settings(BaseSettings): | ||||
|     cache_folder: str = "/cache" | ||||
|     classification_model: str = "microsoft/resnet-50" | ||||
|     clip_image_model: str = "clip-ViT-B-32" | ||||
|     clip_text_model: str = "clip-ViT-B-32" | ||||
|     clip_image_model: str = "ViT-B-32::openai" | ||||
|     clip_text_model: str = "ViT-B-32::openai" | ||||
|     facial_recognition_model: str = "buffalo_l" | ||||
|     min_tag_score: float = 0.9 | ||||
|     eager_startup: bool = False | ||||
| @ -19,14 +20,20 @@ class Settings(BaseSettings): | ||||
|     workers: int = 1 | ||||
|     min_face_score: float = 0.7 | ||||
|     test_full: bool = False | ||||
|     request_threads: int = os.cpu_count() or 4 | ||||
|     model_inter_op_threads: int = 1 | ||||
|     model_intra_op_threads: int = 2 | ||||
| 
 | ||||
|     class Config: | ||||
|         env_prefix = "MACHINE_LEARNING_" | ||||
|         case_sensitive = False | ||||
| 
 | ||||
| 
 | ||||
| _clean_name = str.maketrans(":\\/", "___", ".") | ||||
| 
 | ||||
| 
 | ||||
| def get_cache_dir(model_name: str, model_type: ModelType) -> Path: | ||||
|     return Path(settings.cache_folder, model_type.value, model_name) | ||||
|     return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name) | ||||
| 
 | ||||
| 
 | ||||
| settings = Settings() | ||||
|  | ||||
| @ -1,4 +1,6 @@ | ||||
| import asyncio | ||||
| import os | ||||
| from concurrent.futures import ThreadPoolExecutor | ||||
| from io import BytesIO | ||||
| from typing import Any | ||||
| 
 | ||||
| @ -8,6 +10,8 @@ import uvicorn | ||||
| from fastapi import Body, Depends, FastAPI | ||||
| from PIL import Image | ||||
| 
 | ||||
| from app.models.base import InferenceModel | ||||
| 
 | ||||
| from .config import settings | ||||
| from .models.cache import ModelCache | ||||
| from .schemas import ( | ||||
| @ -25,19 +29,21 @@ app = FastAPI() | ||||
| 
 | ||||
| def init_state() -> None: | ||||
|     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0) | ||||
|     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code | ||||
|     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) | ||||
| 
 | ||||
| 
 | ||||
| async def load_models() -> None: | ||||
|     models = [ | ||||
|         (settings.classification_model, ModelType.IMAGE_CLASSIFICATION), | ||||
|         (settings.clip_image_model, ModelType.CLIP), | ||||
|         (settings.clip_text_model, ModelType.CLIP), | ||||
|         (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION), | ||||
|     models: list[tuple[str, ModelType, dict[str, Any]]] = [ | ||||
|         (settings.classification_model, ModelType.IMAGE_CLASSIFICATION, {}), | ||||
|         (settings.clip_image_model, ModelType.CLIP, {"mode": "vision"}), | ||||
|         (settings.clip_text_model, ModelType.CLIP, {"mode": "text"}), | ||||
|         (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION, {}), | ||||
|     ] | ||||
| 
 | ||||
|     # Get all models | ||||
|     for model_name, model_type in models: | ||||
|         await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup) | ||||
|     for model_name, model_type, model_kwargs in models: | ||||
|         await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup, **model_kwargs) | ||||
| 
 | ||||
| 
 | ||||
| @app.on_event("startup") | ||||
| @ -46,11 +52,16 @@ async def startup_event() -> None: | ||||
|     await load_models() | ||||
| 
 | ||||
| 
 | ||||
| @app.on_event("shutdown") | ||||
| async def shutdown_event() -> None: | ||||
|     app.state.thread_pool.shutdown() | ||||
| 
 | ||||
| 
 | ||||
| def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image: | ||||
|     return Image.open(BytesIO(byte_image)) | ||||
| 
 | ||||
| 
 | ||||
| def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat: | ||||
| def dep_cv_image(byte_image: bytes = Body(...)) -> np.ndarray[int, np.dtype[Any]]: | ||||
|     byte_image_np = np.frombuffer(byte_image, np.uint8) | ||||
|     return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR) | ||||
| 
 | ||||
| @ -74,7 +85,7 @@ async def image_classification( | ||||
|     image: Image.Image = Depends(dep_pil_image), | ||||
| ) -> list[str]: | ||||
|     model = await app.state.model_cache.get(settings.classification_model, ModelType.IMAGE_CLASSIFICATION) | ||||
|     labels = model.predict(image) | ||||
|     labels = await predict(model, image) | ||||
|     return labels | ||||
| 
 | ||||
| 
 | ||||
| @ -86,8 +97,8 @@ async def image_classification( | ||||
| async def clip_encode_image( | ||||
|     image: Image.Image = Depends(dep_pil_image), | ||||
| ) -> list[float]: | ||||
|     model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP) | ||||
|     embedding = model.predict(image) | ||||
|     model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP, mode="vision") | ||||
|     embedding = await predict(model, image) | ||||
|     return embedding | ||||
| 
 | ||||
| 
 | ||||
| @ -97,8 +108,8 @@ async def clip_encode_image( | ||||
|     status_code=200, | ||||
| ) | ||||
| async def clip_encode_text(payload: TextModelRequest) -> list[float]: | ||||
|     model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP) | ||||
|     embedding = model.predict(payload.text) | ||||
|     model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP, mode="text") | ||||
|     embedding = await predict(model, payload.text) | ||||
|     return embedding | ||||
| 
 | ||||
| 
 | ||||
| @ -111,10 +122,14 @@ async def facial_recognition( | ||||
|     image: cv2.Mat = Depends(dep_cv_image), | ||||
| ) -> list[dict[str, Any]]: | ||||
|     model = await app.state.model_cache.get(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION) | ||||
|     faces = model.predict(image) | ||||
|     faces = await predict(model, image) | ||||
|     return faces | ||||
| 
 | ||||
| 
 | ||||
| async def predict(model: InferenceModel, inputs: Any) -> Any: | ||||
|     return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     is_dev = os.getenv("NODE_ENV") == "development" | ||||
|     uvicorn.run( | ||||
|  | ||||
| @ -1,3 +1,3 @@ | ||||
| from .clip import CLIPSTEncoder | ||||
| from .clip import CLIPEncoder | ||||
| from .facial_recognition import FaceRecognizer | ||||
| from .image_classification import ImageClassifier | ||||
|  | ||||
| @ -1,14 +1,17 @@ | ||||
| from __future__ import annotations | ||||
| 
 | ||||
| import os | ||||
| import pickle | ||||
| from abc import ABC, abstractmethod | ||||
| from pathlib import Path | ||||
| from shutil import rmtree | ||||
| from typing import Any | ||||
| from zipfile import BadZipFile | ||||
| 
 | ||||
| import onnxruntime as ort | ||||
| from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore | ||||
| 
 | ||||
| from ..config import get_cache_dir | ||||
| from ..config import get_cache_dir, settings | ||||
| from ..schemas import ModelType | ||||
| 
 | ||||
| 
 | ||||
| @ -16,12 +19,31 @@ class InferenceModel(ABC): | ||||
|     _model_type: ModelType | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, model_name: str, cache_dir: Path | str | None = None, eager: bool = True, **model_kwargs: Any | ||||
|         self, | ||||
|         model_name: str, | ||||
|         cache_dir: Path | str | None = None, | ||||
|         eager: bool = True, | ||||
|         inter_op_num_threads: int = settings.model_inter_op_threads, | ||||
|         intra_op_num_threads: int = settings.model_intra_op_threads, | ||||
|         **model_kwargs: Any, | ||||
|     ) -> None: | ||||
|         self.model_name = model_name | ||||
|         self._loaded = False | ||||
|         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type) | ||||
|         loader = self.load if eager else self.download | ||||
| 
 | ||||
|         self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"]) | ||||
|         #  don't pre-allocate more memory than needed | ||||
|         self.provider_options = model_kwargs.pop( | ||||
|             "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers) | ||||
|         ) | ||||
|         self.sess_options = PicklableSessionOptions() | ||||
|         # avoid thread contention between models | ||||
|         if inter_op_num_threads > 1: | ||||
|             self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL | ||||
|         self.sess_options.inter_op_num_threads = inter_op_num_threads | ||||
|         self.sess_options.intra_op_num_threads = intra_op_num_threads | ||||
| 
 | ||||
|         try: | ||||
|             loader(**model_kwargs) | ||||
|         except (OSError, InvalidProtobuf, BadZipFile): | ||||
| @ -30,6 +52,7 @@ class InferenceModel(ABC): | ||||
| 
 | ||||
|     def download(self, **model_kwargs: Any) -> None: | ||||
|         if not self.cached: | ||||
|             print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...") | ||||
|             self._download(**model_kwargs) | ||||
| 
 | ||||
|     def load(self, **model_kwargs: Any) -> None: | ||||
| @ -39,6 +62,7 @@ class InferenceModel(ABC): | ||||
| 
 | ||||
|     def predict(self, inputs: Any) -> Any: | ||||
|         if not self._loaded: | ||||
|             print(f"Loading {self.model_type.value.replace('_', ' ')} model...") | ||||
|             self.load() | ||||
|         return self._predict(inputs) | ||||
| 
 | ||||
| @ -89,3 +113,14 @@ class InferenceModel(ABC): | ||||
|         else: | ||||
|             self.cache_dir.unlink() | ||||
|         self.cache_dir.mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
| 
 | ||||
| # HF deep copies configs, so we need to make session options picklable | ||||
| class PicklableSessionOptions(ort.SessionOptions): | ||||
|     def __getstate__(self) -> bytes: | ||||
|         return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))]) | ||||
| 
 | ||||
|     def __setstate__(self, state: Any) -> None: | ||||
|         self.__init__()  # type: ignore | ||||
|         for attr, val in pickle.loads(state): | ||||
|             setattr(self, attr, val) | ||||
|  | ||||
| @ -46,7 +46,7 @@ class ModelCache: | ||||
|             model: The requested model. | ||||
|         """ | ||||
| 
 | ||||
|         key = self.cache.build_key(model_name, model_type.value) | ||||
|         key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}" | ||||
|         async with OptimisticLock(self.cache, key) as lock: | ||||
|             model = await self.cache.get(key) | ||||
|             if model is None: | ||||
|  | ||||
| @ -1,31 +1,141 @@ | ||||
| from typing import Any | ||||
| import os | ||||
| import zipfile | ||||
| from typing import Any, Literal | ||||
| 
 | ||||
| import onnxruntime as ort | ||||
| import torch | ||||
| from clip_server.model.clip import BICUBIC, _convert_image_to_rgb | ||||
| from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model | ||||
| from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE | ||||
| from clip_server.model.tokenization import Tokenizer | ||||
| from PIL.Image import Image | ||||
| from sentence_transformers import SentenceTransformer | ||||
| from sentence_transformers.util import snapshot_download | ||||
| from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor | ||||
| 
 | ||||
| from ..schemas import ModelType | ||||
| from .base import InferenceModel | ||||
| 
 | ||||
| _ST_TO_JINA_MODEL_NAME = { | ||||
|     "clip-ViT-B-16": "ViT-B-16::openai", | ||||
|     "clip-ViT-B-32": "ViT-B-32::openai", | ||||
|     "clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32", | ||||
|     "clip-ViT-L-14": "ViT-L-14::openai", | ||||
| } | ||||
| 
 | ||||
| class CLIPSTEncoder(InferenceModel): | ||||
| 
 | ||||
| class CLIPEncoder(InferenceModel): | ||||
|     _model_type = ModelType.CLIP | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         model_name: str, | ||||
|         cache_dir: str | None = None, | ||||
|         mode: Literal["text", "vision"] | None = None, | ||||
|         **model_kwargs: Any, | ||||
|     ) -> None: | ||||
|         if mode is not None and mode not in ("text", "vision"): | ||||
|             raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'") | ||||
|         if "vit-b" not in model_name.lower(): | ||||
|             raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'") | ||||
|         self.mode = mode | ||||
|         jina_model_name = self._get_jina_model_name(model_name) | ||||
|         super().__init__(jina_model_name, cache_dir, **model_kwargs) | ||||
| 
 | ||||
|     def _download(self, **model_kwargs: Any) -> None: | ||||
|         repo_id = self.model_name if "/" in self.model_name else f"sentence-transformers/{self.model_name}" | ||||
|         snapshot_download( | ||||
|             cache_dir=self.cache_dir, | ||||
|             repo_id=repo_id, | ||||
|             library_name="sentence-transformers", | ||||
|             ignore_files=["flax_model.msgpack", "rust_model.ot", "tf_model.h5"], | ||||
|         ) | ||||
|         models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name] | ||||
|         text_onnx_path = self.cache_dir / "textual.onnx" | ||||
|         vision_onnx_path = self.cache_dir / "visual.onnx" | ||||
| 
 | ||||
|         if not text_onnx_path.is_file(): | ||||
|             self._download_model(*models[0]) | ||||
| 
 | ||||
|         if not vision_onnx_path.is_file(): | ||||
|             self._download_model(*models[1]) | ||||
| 
 | ||||
|     def _load(self, **model_kwargs: Any) -> None: | ||||
|         self.model = SentenceTransformer( | ||||
|             self.model_name, | ||||
|             cache_folder=self.cache_dir.as_posix(), | ||||
|             **model_kwargs, | ||||
|         ) | ||||
|         if self.mode == "text" or self.mode is None: | ||||
|             self.text_model = ort.InferenceSession( | ||||
|                 self.cache_dir / "textual.onnx", | ||||
|                 sess_options=self.sess_options, | ||||
|                 providers=self.providers, | ||||
|                 provider_options=self.provider_options, | ||||
|             ) | ||||
|             self.text_outputs = [output.name for output in self.text_model.get_outputs()] | ||||
|             self.tokenizer = Tokenizer(self.model_name) | ||||
| 
 | ||||
|         if self.mode == "vision" or self.mode is None: | ||||
|             self.vision_model = ort.InferenceSession( | ||||
|                 self.cache_dir / "visual.onnx", | ||||
|                 sess_options=self.sess_options, | ||||
|                 providers=self.providers, | ||||
|                 provider_options=self.provider_options, | ||||
|             ) | ||||
|             self.vision_outputs = [output.name for output in self.vision_model.get_outputs()] | ||||
| 
 | ||||
|             image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)] | ||||
|             self.transform = _transform_pil_image(image_size) | ||||
| 
 | ||||
|     def _predict(self, image_or_text: Image | str) -> list[float]: | ||||
|         return self.model.encode(image_or_text).tolist() | ||||
|         match image_or_text: | ||||
|             case Image(): | ||||
|                 if self.mode == "text": | ||||
|                     raise TypeError("Cannot encode image as text-only model") | ||||
|                 pixel_values = self.transform(image_or_text) | ||||
|                 assert isinstance(pixel_values, torch.Tensor) | ||||
|                 pixel_values = torch.unsqueeze(pixel_values, 0).numpy() | ||||
|                 outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values}) | ||||
|             case str(): | ||||
|                 if self.mode == "vision": | ||||
|                     raise TypeError("Cannot encode text as vision-only model") | ||||
|                 text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text) | ||||
|                 inputs = { | ||||
|                     "input_ids": text_inputs["input_ids"].int().numpy(), | ||||
|                     "attention_mask": text_inputs["attention_mask"].int().numpy(), | ||||
|                 } | ||||
|                 outputs = self.text_model.run(self.text_outputs, inputs) | ||||
|             case _: | ||||
|                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}") | ||||
| 
 | ||||
|         return outputs[0][0].tolist() | ||||
| 
 | ||||
|     def _get_jina_model_name(self, model_name: str) -> str: | ||||
|         if model_name in _MODELS: | ||||
|             return model_name | ||||
|         elif model_name in _ST_TO_JINA_MODEL_NAME: | ||||
|             print( | ||||
|                 (f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported."), | ||||
|                 (f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."), | ||||
|             ) | ||||
|             return _ST_TO_JINA_MODEL_NAME[model_name] | ||||
|         else: | ||||
|             raise ValueError(f"Unknown model name {model_name}.") | ||||
| 
 | ||||
|     def _download_model(self, model_name: str, model_md5: str) -> bool: | ||||
|         # downloading logic is adapted from clip-server's CLIPOnnxModel class | ||||
|         download_model( | ||||
|             url=_S3_BUCKET_V2 + model_name, | ||||
|             target_folder=self.cache_dir.as_posix(), | ||||
|             md5sum=model_md5, | ||||
|             with_resume=True, | ||||
|         ) | ||||
|         file = self.cache_dir / model_name.split("/")[1] | ||||
|         if file.suffix == ".zip": | ||||
|             with zipfile.ZipFile(file, "r") as zip_ref: | ||||
|                 zip_ref.extractall(self.cache_dir) | ||||
|             os.remove(file) | ||||
|         return True | ||||
| 
 | ||||
| 
 | ||||
| # same as `_transform_blob` without `_blob2image` | ||||
| def _transform_pil_image(n_px: int) -> Compose: | ||||
|     return Compose( | ||||
|         [ | ||||
|             Resize(n_px, interpolation=BICUBIC), | ||||
|             CenterCrop(n_px), | ||||
|             _convert_image_to_rgb, | ||||
|             ToTensor(), | ||||
|             Normalize( | ||||
|                 (0.48145466, 0.4578275, 0.40821073), | ||||
|                 (0.26862954, 0.26130258, 0.27577711), | ||||
|             ), | ||||
|         ] | ||||
|     ) | ||||
|  | ||||
| @ -4,6 +4,7 @@ from typing import Any | ||||
| 
 | ||||
| import cv2 | ||||
| import numpy as np | ||||
| import onnxruntime as ort | ||||
| from insightface.model_zoo import ArcFaceONNX, RetinaFace | ||||
| from insightface.utils.face_align import norm_crop | ||||
| from insightface.utils.storage import BASE_REPO_URL, download_file | ||||
| @ -42,15 +43,31 @@ class FaceRecognizer(InferenceModel): | ||||
|             rec_file = next(self.cache_dir.glob("w600k_*.onnx")) | ||||
|         except StopIteration: | ||||
|             raise FileNotFoundError("Facial recognition models not found in cache directory") | ||||
|         self.det_model = RetinaFace(det_file.as_posix()) | ||||
|         self.rec_model = ArcFaceONNX(rec_file.as_posix()) | ||||
| 
 | ||||
|         self.det_model = RetinaFace( | ||||
|             session=ort.InferenceSession( | ||||
|                 det_file.as_posix(), | ||||
|                 sess_options=self.sess_options, | ||||
|                 providers=self.providers, | ||||
|                 provider_options=self.provider_options, | ||||
|             ), | ||||
|         ) | ||||
|         self.rec_model = ArcFaceONNX( | ||||
|             rec_file.as_posix(), | ||||
|             session=ort.InferenceSession( | ||||
|                 rec_file.as_posix(), | ||||
|                 sess_options=self.sess_options, | ||||
|                 providers=self.providers, | ||||
|                 provider_options=self.provider_options, | ||||
|             ), | ||||
|         ) | ||||
| 
 | ||||
|         self.det_model.prepare( | ||||
|             ctx_id=-1, | ||||
|             ctx_id=0, | ||||
|             det_thresh=self.min_score, | ||||
|             input_size=(640, 640), | ||||
|         ) | ||||
|         self.rec_model.prepare(ctx_id=-1) | ||||
|         self.rec_model.prepare(ctx_id=0) | ||||
| 
 | ||||
|     def _predict(self, image: cv2.Mat) -> list[dict[str, Any]]: | ||||
|         bboxes, kpss = self.det_model.detect(image) | ||||
|  | ||||
| @ -2,8 +2,10 @@ from pathlib import Path | ||||
| from typing import Any | ||||
| 
 | ||||
| from huggingface_hub import snapshot_download | ||||
| from optimum.onnxruntime import ORTModelForImageClassification | ||||
| from optimum.pipelines import pipeline | ||||
| from PIL.Image import Image | ||||
| from transformers.pipelines import pipeline | ||||
| from transformers import AutoImageProcessor | ||||
| 
 | ||||
| from ..config import settings | ||||
| from ..schemas import ModelType | ||||
| @ -25,15 +27,34 @@ class ImageClassifier(InferenceModel): | ||||
| 
 | ||||
|     def _download(self, **model_kwargs: Any) -> None: | ||||
|         snapshot_download( | ||||
|             cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"] | ||||
|             cache_dir=self.cache_dir, | ||||
|             repo_id=self.model_name, | ||||
|             allow_patterns=["*.bin", "*.json", "*.txt"], | ||||
|             local_dir=self.cache_dir, | ||||
|             local_dir_use_symlinks=True, | ||||
|         ) | ||||
| 
 | ||||
|     def _load(self, **model_kwargs: Any) -> None: | ||||
|         self.model = pipeline( | ||||
|             self.model_type.value, | ||||
|             self.model_name, | ||||
|             model_kwargs={"cache_dir": self.cache_dir, **model_kwargs}, | ||||
|         ) | ||||
|         processor = AutoImageProcessor.from_pretrained(self.cache_dir) | ||||
|         model_kwargs |= { | ||||
|             "cache_dir": self.cache_dir, | ||||
|             "provider": self.providers[0], | ||||
|             "provider_options": self.provider_options[0], | ||||
|             "session_options": self.sess_options, | ||||
|         } | ||||
|         model_path = self.cache_dir / "model.onnx" | ||||
| 
 | ||||
|         if model_path.exists(): | ||||
|             model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs) | ||||
|             self.model = pipeline(self.model_type.value, model, feature_extractor=processor) | ||||
|         else: | ||||
|             self.sess_options.optimized_model_filepath = model_path.as_posix() | ||||
|             self.model = pipeline( | ||||
|                 self.model_type.value, | ||||
|                 self.model_name, | ||||
|                 model_kwargs=model_kwargs, | ||||
|                 feature_extractor=processor, | ||||
|             ) | ||||
| 
 | ||||
|     def _predict(self, image: Image) -> list[str]: | ||||
|         predictions: list[dict[str, Any]] = self.model(image)  # type: ignore | ||||
|  | ||||
| @ -1,17 +1,20 @@ | ||||
| import pickle | ||||
| from io import BytesIO | ||||
| from typing import TypeAlias | ||||
| from unittest import mock | ||||
| 
 | ||||
| import cv2 | ||||
| import numpy as np | ||||
| import onnxruntime as ort | ||||
| import pytest | ||||
| from fastapi.testclient import TestClient | ||||
| from PIL import Image | ||||
| from pytest_mock import MockerFixture | ||||
| 
 | ||||
| from .config import settings | ||||
| from .models.base import PicklableSessionOptions | ||||
| from .models.cache import ModelCache | ||||
| from .models.clip import CLIPSTEncoder | ||||
| from .models.clip import CLIPEncoder | ||||
| from .models.facial_recognition import FaceRecognizer | ||||
| from .models.image_classification import ImageClassifier | ||||
| from .schemas import ModelType | ||||
| @ -72,45 +75,47 @@ class TestCLIP: | ||||
|     embedding = np.random.rand(512).astype(np.float32) | ||||
| 
 | ||||
|     def test_eager_init(self, mocker: MockerFixture) -> None: | ||||
|         mocker.patch.object(CLIPSTEncoder, "download") | ||||
|         mock_load = mocker.patch.object(CLIPSTEncoder, "load") | ||||
|         clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg") | ||||
|         mocker.patch.object(CLIPEncoder, "download") | ||||
|         mock_load = mocker.patch.object(CLIPEncoder, "load") | ||||
|         clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=True, test_arg="test_arg") | ||||
| 
 | ||||
|         assert clip_model.model_name == "test_model_name" | ||||
|         assert clip_model.model_name == "ViT-B-32::openai" | ||||
|         mock_load.assert_called_once_with(test_arg="test_arg") | ||||
| 
 | ||||
|     def test_lazy_init(self, mocker: MockerFixture) -> None: | ||||
|         mock_download = mocker.patch.object(CLIPSTEncoder, "download") | ||||
|         mock_load = mocker.patch.object(CLIPSTEncoder, "load") | ||||
|         clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg") | ||||
|         mock_download = mocker.patch.object(CLIPEncoder, "download") | ||||
|         mock_load = mocker.patch.object(CLIPEncoder, "load") | ||||
|         clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=False, test_arg="test_arg") | ||||
| 
 | ||||
|         assert clip_model.model_name == "test_model_name" | ||||
|         assert clip_model.model_name == "ViT-B-32::openai" | ||||
|         mock_download.assert_called_once_with(test_arg="test_arg") | ||||
|         mock_load.assert_not_called() | ||||
| 
 | ||||
|     def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None: | ||||
|         mocker.patch.object(CLIPSTEncoder, "load") | ||||
|         clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache") | ||||
|         clip_encoder.model = mock.Mock() | ||||
|         clip_encoder.model.encode.return_value = self.embedding | ||||
|         mocker.patch.object(CLIPEncoder, "download") | ||||
|         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True) | ||||
|         mocked.return_value.run.return_value = [[self.embedding]] | ||||
|         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision") | ||||
|         assert clip_encoder.mode == "vision" | ||||
|         embedding = clip_encoder.predict(pil_image) | ||||
| 
 | ||||
|         assert isinstance(embedding, list) | ||||
|         assert len(embedding) == 512 | ||||
|         assert all([isinstance(num, float) for num in embedding]) | ||||
|         clip_encoder.model.encode.assert_called_once() | ||||
|         clip_encoder.vision_model.run.assert_called_once() | ||||
| 
 | ||||
|     def test_basic_text(self, mocker: MockerFixture) -> None: | ||||
|         mocker.patch.object(CLIPSTEncoder, "load") | ||||
|         clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache") | ||||
|         clip_encoder.model = mock.Mock() | ||||
|         clip_encoder.model.encode.return_value = self.embedding | ||||
|         mocker.patch.object(CLIPEncoder, "download") | ||||
|         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True) | ||||
|         mocked.return_value.run.return_value = [[self.embedding]] | ||||
|         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text") | ||||
|         assert clip_encoder.mode == "text" | ||||
|         embedding = clip_encoder.predict("test search query") | ||||
| 
 | ||||
|         assert isinstance(embedding, list) | ||||
|         assert len(embedding) == 512 | ||||
|         assert all([isinstance(num, float) for num in embedding]) | ||||
|         clip_encoder.model.encode.assert_called_once() | ||||
|         clip_encoder.text_model.run.assert_called_once() | ||||
| 
 | ||||
| 
 | ||||
| class TestFaceRecognition: | ||||
| @ -254,3 +259,13 @@ class TestEndpoints: | ||||
|             headers=headers, | ||||
|         ) | ||||
|         assert response.status_code == 200 | ||||
| 
 | ||||
| 
 | ||||
| def test_sess_options() -> None: | ||||
|     sess_options = PicklableSessionOptions() | ||||
|     sess_options.intra_op_num_threads = 1 | ||||
|     sess_options.inter_op_num_threads = 1 | ||||
|     pickled = pickle.dumps(sess_options) | ||||
|     unpickled = pickle.loads(pickled) | ||||
|     assert unpickled.intra_op_num_threads == 1 | ||||
|     assert unpickled.inter_op_num_threads == 1 | ||||
|  | ||||
							
								
								
									
										1739
									
								
								machine-learning/poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1739
									
								
								machine-learning/poetry.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -13,7 +13,6 @@ torch = [ | ||||
|     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.0.1", source = "pytorch-cpu"} | ||||
| ] | ||||
| transformers = "^4.29.2" | ||||
| sentence-transformers = "^2.2.2" | ||||
| onnxruntime = "^1.15.0" | ||||
| insightface = "^0.7.3" | ||||
| opencv-python-headless = "^4.7.0.72" | ||||
| @ -22,6 +21,15 @@ fastapi = "^0.95.2" | ||||
| uvicorn = {extras = ["standard"], version = "^0.22.0"} | ||||
| pydantic = "^1.10.8" | ||||
| aiocache = "^0.12.1" | ||||
| optimum = "^1.9.1" | ||||
| torchvision = [ | ||||
|     {markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=0.15.2", source = "pypi"}, | ||||
|     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=0.15.2", source = "pytorch-cpu"} | ||||
| ] | ||||
| rich = "^13.4.2" | ||||
| ftfy = "^6.1.1" | ||||
| setuptools = "^68.0.0" | ||||
| open-clip-torch = "^2.20.0" | ||||
| 
 | ||||
| [tool.poetry.group.dev.dependencies] | ||||
| mypy = "^1.3.0" | ||||
| @ -62,13 +70,20 @@ warn_untyped_fields = true | ||||
| [[tool.mypy.overrides]] | ||||
| module = [ | ||||
|     "huggingface_hub", | ||||
|     "transformers.pipelines", | ||||
|     "transformers", | ||||
|     "cv2", | ||||
|     "insightface.model_zoo", | ||||
|     "insightface.utils.face_align", | ||||
|     "insightface.utils.storage", | ||||
|     "sentence_transformers", | ||||
|     "sentence_transformers.util", | ||||
|     "onnxruntime", | ||||
|     "optimum", | ||||
|     "optimum.pipelines", | ||||
|     "optimum.onnxruntime", | ||||
|     "clip_server.model.clip", | ||||
|     "clip_server.model.clip_onnx", | ||||
|     "clip_server.model.pretrained_models", | ||||
|     "clip_server.model.tokenization", | ||||
|     "torchvision.transforms", | ||||
|     "aiocache.backends.memory", | ||||
|     "aiocache.lock", | ||||
|     "aiocache.plugins" | ||||
|  | ||||
							
								
								
									
										2
									
								
								machine-learning/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								machine-learning/requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | ||||
| # requirements to be installed with `--no-deps` flag | ||||
| clip-server==0.8.* | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user