mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-26 00:14:40 -04:00 
			
		
		
		
	feat(ml)!: switch image classification and CLIP models to ONNX (#3809)
This commit is contained in:
		
							parent
							
								
									8211afb726
								
							
						
					
					
						commit
						165b91b068
					
				
							
								
								
									
										1
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							| @ -171,6 +171,7 @@ jobs: | |||||||
|       - name: Install dependencies |       - name: Install dependencies | ||||||
|         run: | |         run: | | ||||||
|           poetry install --with dev |           poetry install --with dev | ||||||
|  |           poetry run pip install --no-deps -r requirements.txt | ||||||
|       - name: Lint with ruff |       - name: Lint with ruff | ||||||
|         run: | |         run: | | ||||||
|           poetry run ruff check --format=github app |           poetry run ruff check --format=github app | ||||||
|  | |||||||
| @ -10,8 +10,9 @@ RUN poetry config installer.max-workers 10 && \ | |||||||
| RUN python -m venv /opt/venv | RUN python -m venv /opt/venv | ||||||
| ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}" | ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}" | ||||||
| 
 | 
 | ||||||
| COPY poetry.lock pyproject.toml ./ | COPY poetry.lock pyproject.toml requirements.txt ./ | ||||||
| RUN poetry install --sync --no-interaction --no-ansi --no-root --only main | RUN poetry install --sync --no-interaction --no-ansi --no-root --only main | ||||||
|  | RUN pip install --no-deps -r requirements.txt | ||||||
| 
 | 
 | ||||||
| FROM python:3.11.4-slim-bullseye@sha256:91d194f58f50594cda71dcd2e8fdefd90e7ecc57d07823813b67c8521e565dcd | FROM python:3.11.4-slim-bullseye@sha256:91d194f58f50594cda71dcd2e8fdefd90e7ecc57d07823813b67c8521e565dcd | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,3 +1,4 @@ | |||||||
|  | import os | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from pydantic import BaseSettings | from pydantic import BaseSettings | ||||||
| @ -8,8 +9,8 @@ from .schemas import ModelType | |||||||
| class Settings(BaseSettings): | class Settings(BaseSettings): | ||||||
|     cache_folder: str = "/cache" |     cache_folder: str = "/cache" | ||||||
|     classification_model: str = "microsoft/resnet-50" |     classification_model: str = "microsoft/resnet-50" | ||||||
|     clip_image_model: str = "clip-ViT-B-32" |     clip_image_model: str = "ViT-B-32::openai" | ||||||
|     clip_text_model: str = "clip-ViT-B-32" |     clip_text_model: str = "ViT-B-32::openai" | ||||||
|     facial_recognition_model: str = "buffalo_l" |     facial_recognition_model: str = "buffalo_l" | ||||||
|     min_tag_score: float = 0.9 |     min_tag_score: float = 0.9 | ||||||
|     eager_startup: bool = False |     eager_startup: bool = False | ||||||
| @ -19,14 +20,20 @@ class Settings(BaseSettings): | |||||||
|     workers: int = 1 |     workers: int = 1 | ||||||
|     min_face_score: float = 0.7 |     min_face_score: float = 0.7 | ||||||
|     test_full: bool = False |     test_full: bool = False | ||||||
|  |     request_threads: int = os.cpu_count() or 4 | ||||||
|  |     model_inter_op_threads: int = 1 | ||||||
|  |     model_intra_op_threads: int = 2 | ||||||
| 
 | 
 | ||||||
|     class Config: |     class Config: | ||||||
|         env_prefix = "MACHINE_LEARNING_" |         env_prefix = "MACHINE_LEARNING_" | ||||||
|         case_sensitive = False |         case_sensitive = False | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | _clean_name = str.maketrans(":\\/", "___", ".") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def get_cache_dir(model_name: str, model_type: ModelType) -> Path: | def get_cache_dir(model_name: str, model_type: ModelType) -> Path: | ||||||
|     return Path(settings.cache_folder, model_type.value, model_name) |     return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| settings = Settings() | settings = Settings() | ||||||
|  | |||||||
| @ -1,4 +1,6 @@ | |||||||
|  | import asyncio | ||||||
| import os | import os | ||||||
|  | from concurrent.futures import ThreadPoolExecutor | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
| from typing import Any | from typing import Any | ||||||
| 
 | 
 | ||||||
| @ -8,6 +10,8 @@ import uvicorn | |||||||
| from fastapi import Body, Depends, FastAPI | from fastapi import Body, Depends, FastAPI | ||||||
| from PIL import Image | from PIL import Image | ||||||
| 
 | 
 | ||||||
|  | from app.models.base import InferenceModel | ||||||
|  | 
 | ||||||
| from .config import settings | from .config import settings | ||||||
| from .models.cache import ModelCache | from .models.cache import ModelCache | ||||||
| from .schemas import ( | from .schemas import ( | ||||||
| @ -25,19 +29,21 @@ app = FastAPI() | |||||||
| 
 | 
 | ||||||
| def init_state() -> None: | def init_state() -> None: | ||||||
|     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0) |     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0) | ||||||
|  |     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code | ||||||
|  |     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def load_models() -> None: | async def load_models() -> None: | ||||||
|     models = [ |     models: list[tuple[str, ModelType, dict[str, Any]]] = [ | ||||||
|         (settings.classification_model, ModelType.IMAGE_CLASSIFICATION), |         (settings.classification_model, ModelType.IMAGE_CLASSIFICATION, {}), | ||||||
|         (settings.clip_image_model, ModelType.CLIP), |         (settings.clip_image_model, ModelType.CLIP, {"mode": "vision"}), | ||||||
|         (settings.clip_text_model, ModelType.CLIP), |         (settings.clip_text_model, ModelType.CLIP, {"mode": "text"}), | ||||||
|         (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION), |         (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION, {}), | ||||||
|     ] |     ] | ||||||
| 
 | 
 | ||||||
|     # Get all models |     # Get all models | ||||||
|     for model_name, model_type in models: |     for model_name, model_type, model_kwargs in models: | ||||||
|         await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup) |         await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup, **model_kwargs) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.on_event("startup") | @app.on_event("startup") | ||||||
| @ -46,11 +52,16 @@ async def startup_event() -> None: | |||||||
|     await load_models() |     await load_models() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @app.on_event("shutdown") | ||||||
|  | async def shutdown_event() -> None: | ||||||
|  |     app.state.thread_pool.shutdown() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image: | def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image: | ||||||
|     return Image.open(BytesIO(byte_image)) |     return Image.open(BytesIO(byte_image)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat: | def dep_cv_image(byte_image: bytes = Body(...)) -> np.ndarray[int, np.dtype[Any]]: | ||||||
|     byte_image_np = np.frombuffer(byte_image, np.uint8) |     byte_image_np = np.frombuffer(byte_image, np.uint8) | ||||||
|     return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR) |     return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR) | ||||||
| 
 | 
 | ||||||
| @ -74,7 +85,7 @@ async def image_classification( | |||||||
|     image: Image.Image = Depends(dep_pil_image), |     image: Image.Image = Depends(dep_pil_image), | ||||||
| ) -> list[str]: | ) -> list[str]: | ||||||
|     model = await app.state.model_cache.get(settings.classification_model, ModelType.IMAGE_CLASSIFICATION) |     model = await app.state.model_cache.get(settings.classification_model, ModelType.IMAGE_CLASSIFICATION) | ||||||
|     labels = model.predict(image) |     labels = await predict(model, image) | ||||||
|     return labels |     return labels | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -86,8 +97,8 @@ async def image_classification( | |||||||
| async def clip_encode_image( | async def clip_encode_image( | ||||||
|     image: Image.Image = Depends(dep_pil_image), |     image: Image.Image = Depends(dep_pil_image), | ||||||
| ) -> list[float]: | ) -> list[float]: | ||||||
|     model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP) |     model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP, mode="vision") | ||||||
|     embedding = model.predict(image) |     embedding = await predict(model, image) | ||||||
|     return embedding |     return embedding | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -97,8 +108,8 @@ async def clip_encode_image( | |||||||
|     status_code=200, |     status_code=200, | ||||||
| ) | ) | ||||||
| async def clip_encode_text(payload: TextModelRequest) -> list[float]: | async def clip_encode_text(payload: TextModelRequest) -> list[float]: | ||||||
|     model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP) |     model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP, mode="text") | ||||||
|     embedding = model.predict(payload.text) |     embedding = await predict(model, payload.text) | ||||||
|     return embedding |     return embedding | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -111,10 +122,14 @@ async def facial_recognition( | |||||||
|     image: cv2.Mat = Depends(dep_cv_image), |     image: cv2.Mat = Depends(dep_cv_image), | ||||||
| ) -> list[dict[str, Any]]: | ) -> list[dict[str, Any]]: | ||||||
|     model = await app.state.model_cache.get(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION) |     model = await app.state.model_cache.get(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION) | ||||||
|     faces = model.predict(image) |     faces = await predict(model, image) | ||||||
|     return faces |     return faces | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | async def predict(model: InferenceModel, inputs: Any) -> Any: | ||||||
|  |     return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     is_dev = os.getenv("NODE_ENV") == "development" |     is_dev = os.getenv("NODE_ENV") == "development" | ||||||
|     uvicorn.run( |     uvicorn.run( | ||||||
|  | |||||||
| @ -1,3 +1,3 @@ | |||||||
| from .clip import CLIPSTEncoder | from .clip import CLIPEncoder | ||||||
| from .facial_recognition import FaceRecognizer | from .facial_recognition import FaceRecognizer | ||||||
| from .image_classification import ImageClassifier | from .image_classification import ImageClassifier | ||||||
|  | |||||||
| @ -1,14 +1,17 @@ | |||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
| 
 | 
 | ||||||
|  | import os | ||||||
|  | import pickle | ||||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from shutil import rmtree | from shutil import rmtree | ||||||
| from typing import Any | from typing import Any | ||||||
| from zipfile import BadZipFile | from zipfile import BadZipFile | ||||||
| 
 | 
 | ||||||
|  | import onnxruntime as ort | ||||||
| from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore | from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore | ||||||
| 
 | 
 | ||||||
| from ..config import get_cache_dir | from ..config import get_cache_dir, settings | ||||||
| from ..schemas import ModelType | from ..schemas import ModelType | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -16,12 +19,31 @@ class InferenceModel(ABC): | |||||||
|     _model_type: ModelType |     _model_type: ModelType | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, model_name: str, cache_dir: Path | str | None = None, eager: bool = True, **model_kwargs: Any |         self, | ||||||
|  |         model_name: str, | ||||||
|  |         cache_dir: Path | str | None = None, | ||||||
|  |         eager: bool = True, | ||||||
|  |         inter_op_num_threads: int = settings.model_inter_op_threads, | ||||||
|  |         intra_op_num_threads: int = settings.model_intra_op_threads, | ||||||
|  |         **model_kwargs: Any, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self.model_name = model_name |         self.model_name = model_name | ||||||
|         self._loaded = False |         self._loaded = False | ||||||
|         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type) |         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type) | ||||||
|         loader = self.load if eager else self.download |         loader = self.load if eager else self.download | ||||||
|  | 
 | ||||||
|  |         self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"]) | ||||||
|  |         #  don't pre-allocate more memory than needed | ||||||
|  |         self.provider_options = model_kwargs.pop( | ||||||
|  |             "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers) | ||||||
|  |         ) | ||||||
|  |         self.sess_options = PicklableSessionOptions() | ||||||
|  |         # avoid thread contention between models | ||||||
|  |         if inter_op_num_threads > 1: | ||||||
|  |             self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL | ||||||
|  |         self.sess_options.inter_op_num_threads = inter_op_num_threads | ||||||
|  |         self.sess_options.intra_op_num_threads = intra_op_num_threads | ||||||
|  | 
 | ||||||
|         try: |         try: | ||||||
|             loader(**model_kwargs) |             loader(**model_kwargs) | ||||||
|         except (OSError, InvalidProtobuf, BadZipFile): |         except (OSError, InvalidProtobuf, BadZipFile): | ||||||
| @ -30,6 +52,7 @@ class InferenceModel(ABC): | |||||||
| 
 | 
 | ||||||
|     def download(self, **model_kwargs: Any) -> None: |     def download(self, **model_kwargs: Any) -> None: | ||||||
|         if not self.cached: |         if not self.cached: | ||||||
|  |             print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...") | ||||||
|             self._download(**model_kwargs) |             self._download(**model_kwargs) | ||||||
| 
 | 
 | ||||||
|     def load(self, **model_kwargs: Any) -> None: |     def load(self, **model_kwargs: Any) -> None: | ||||||
| @ -39,6 +62,7 @@ class InferenceModel(ABC): | |||||||
| 
 | 
 | ||||||
|     def predict(self, inputs: Any) -> Any: |     def predict(self, inputs: Any) -> Any: | ||||||
|         if not self._loaded: |         if not self._loaded: | ||||||
|  |             print(f"Loading {self.model_type.value.replace('_', ' ')} model...") | ||||||
|             self.load() |             self.load() | ||||||
|         return self._predict(inputs) |         return self._predict(inputs) | ||||||
| 
 | 
 | ||||||
| @ -89,3 +113,14 @@ class InferenceModel(ABC): | |||||||
|         else: |         else: | ||||||
|             self.cache_dir.unlink() |             self.cache_dir.unlink() | ||||||
|         self.cache_dir.mkdir(parents=True, exist_ok=True) |         self.cache_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # HF deep copies configs, so we need to make session options picklable | ||||||
|  | class PicklableSessionOptions(ort.SessionOptions): | ||||||
|  |     def __getstate__(self) -> bytes: | ||||||
|  |         return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))]) | ||||||
|  | 
 | ||||||
|  |     def __setstate__(self, state: Any) -> None: | ||||||
|  |         self.__init__()  # type: ignore | ||||||
|  |         for attr, val in pickle.loads(state): | ||||||
|  |             setattr(self, attr, val) | ||||||
|  | |||||||
| @ -46,7 +46,7 @@ class ModelCache: | |||||||
|             model: The requested model. |             model: The requested model. | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
|         key = self.cache.build_key(model_name, model_type.value) |         key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}" | ||||||
|         async with OptimisticLock(self.cache, key) as lock: |         async with OptimisticLock(self.cache, key) as lock: | ||||||
|             model = await self.cache.get(key) |             model = await self.cache.get(key) | ||||||
|             if model is None: |             if model is None: | ||||||
|  | |||||||
| @ -1,31 +1,141 @@ | |||||||
| from typing import Any | import os | ||||||
|  | import zipfile | ||||||
|  | from typing import Any, Literal | ||||||
| 
 | 
 | ||||||
|  | import onnxruntime as ort | ||||||
|  | import torch | ||||||
|  | from clip_server.model.clip import BICUBIC, _convert_image_to_rgb | ||||||
|  | from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model | ||||||
|  | from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE | ||||||
|  | from clip_server.model.tokenization import Tokenizer | ||||||
| from PIL.Image import Image | from PIL.Image import Image | ||||||
| from sentence_transformers import SentenceTransformer | from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor | ||||||
| from sentence_transformers.util import snapshot_download |  | ||||||
| 
 | 
 | ||||||
| from ..schemas import ModelType | from ..schemas import ModelType | ||||||
| from .base import InferenceModel | from .base import InferenceModel | ||||||
| 
 | 
 | ||||||
|  | _ST_TO_JINA_MODEL_NAME = { | ||||||
|  |     "clip-ViT-B-16": "ViT-B-16::openai", | ||||||
|  |     "clip-ViT-B-32": "ViT-B-32::openai", | ||||||
|  |     "clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32", | ||||||
|  |     "clip-ViT-L-14": "ViT-L-14::openai", | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| class CLIPSTEncoder(InferenceModel): | 
 | ||||||
|  | class CLIPEncoder(InferenceModel): | ||||||
|     _model_type = ModelType.CLIP |     _model_type = ModelType.CLIP | ||||||
| 
 | 
 | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         model_name: str, | ||||||
|  |         cache_dir: str | None = None, | ||||||
|  |         mode: Literal["text", "vision"] | None = None, | ||||||
|  |         **model_kwargs: Any, | ||||||
|  |     ) -> None: | ||||||
|  |         if mode is not None and mode not in ("text", "vision"): | ||||||
|  |             raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'") | ||||||
|  |         if "vit-b" not in model_name.lower(): | ||||||
|  |             raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'") | ||||||
|  |         self.mode = mode | ||||||
|  |         jina_model_name = self._get_jina_model_name(model_name) | ||||||
|  |         super().__init__(jina_model_name, cache_dir, **model_kwargs) | ||||||
|  | 
 | ||||||
|     def _download(self, **model_kwargs: Any) -> None: |     def _download(self, **model_kwargs: Any) -> None: | ||||||
|         repo_id = self.model_name if "/" in self.model_name else f"sentence-transformers/{self.model_name}" |         models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name] | ||||||
|         snapshot_download( |         text_onnx_path = self.cache_dir / "textual.onnx" | ||||||
|             cache_dir=self.cache_dir, |         vision_onnx_path = self.cache_dir / "visual.onnx" | ||||||
|             repo_id=repo_id, | 
 | ||||||
|             library_name="sentence-transformers", |         if not text_onnx_path.is_file(): | ||||||
|             ignore_files=["flax_model.msgpack", "rust_model.ot", "tf_model.h5"], |             self._download_model(*models[0]) | ||||||
|         ) | 
 | ||||||
|  |         if not vision_onnx_path.is_file(): | ||||||
|  |             self._download_model(*models[1]) | ||||||
| 
 | 
 | ||||||
|     def _load(self, **model_kwargs: Any) -> None: |     def _load(self, **model_kwargs: Any) -> None: | ||||||
|         self.model = SentenceTransformer( |         if self.mode == "text" or self.mode is None: | ||||||
|             self.model_name, |             self.text_model = ort.InferenceSession( | ||||||
|             cache_folder=self.cache_dir.as_posix(), |                 self.cache_dir / "textual.onnx", | ||||||
|             **model_kwargs, |                 sess_options=self.sess_options, | ||||||
|         ) |                 providers=self.providers, | ||||||
|  |                 provider_options=self.provider_options, | ||||||
|  |             ) | ||||||
|  |             self.text_outputs = [output.name for output in self.text_model.get_outputs()] | ||||||
|  |             self.tokenizer = Tokenizer(self.model_name) | ||||||
|  | 
 | ||||||
|  |         if self.mode == "vision" or self.mode is None: | ||||||
|  |             self.vision_model = ort.InferenceSession( | ||||||
|  |                 self.cache_dir / "visual.onnx", | ||||||
|  |                 sess_options=self.sess_options, | ||||||
|  |                 providers=self.providers, | ||||||
|  |                 provider_options=self.provider_options, | ||||||
|  |             ) | ||||||
|  |             self.vision_outputs = [output.name for output in self.vision_model.get_outputs()] | ||||||
|  | 
 | ||||||
|  |             image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)] | ||||||
|  |             self.transform = _transform_pil_image(image_size) | ||||||
| 
 | 
 | ||||||
|     def _predict(self, image_or_text: Image | str) -> list[float]: |     def _predict(self, image_or_text: Image | str) -> list[float]: | ||||||
|         return self.model.encode(image_or_text).tolist() |         match image_or_text: | ||||||
|  |             case Image(): | ||||||
|  |                 if self.mode == "text": | ||||||
|  |                     raise TypeError("Cannot encode image as text-only model") | ||||||
|  |                 pixel_values = self.transform(image_or_text) | ||||||
|  |                 assert isinstance(pixel_values, torch.Tensor) | ||||||
|  |                 pixel_values = torch.unsqueeze(pixel_values, 0).numpy() | ||||||
|  |                 outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values}) | ||||||
|  |             case str(): | ||||||
|  |                 if self.mode == "vision": | ||||||
|  |                     raise TypeError("Cannot encode text as vision-only model") | ||||||
|  |                 text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text) | ||||||
|  |                 inputs = { | ||||||
|  |                     "input_ids": text_inputs["input_ids"].int().numpy(), | ||||||
|  |                     "attention_mask": text_inputs["attention_mask"].int().numpy(), | ||||||
|  |                 } | ||||||
|  |                 outputs = self.text_model.run(self.text_outputs, inputs) | ||||||
|  |             case _: | ||||||
|  |                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}") | ||||||
|  | 
 | ||||||
|  |         return outputs[0][0].tolist() | ||||||
|  | 
 | ||||||
|  |     def _get_jina_model_name(self, model_name: str) -> str: | ||||||
|  |         if model_name in _MODELS: | ||||||
|  |             return model_name | ||||||
|  |         elif model_name in _ST_TO_JINA_MODEL_NAME: | ||||||
|  |             print( | ||||||
|  |                 (f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported."), | ||||||
|  |                 (f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."), | ||||||
|  |             ) | ||||||
|  |             return _ST_TO_JINA_MODEL_NAME[model_name] | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f"Unknown model name {model_name}.") | ||||||
|  | 
 | ||||||
|  |     def _download_model(self, model_name: str, model_md5: str) -> bool: | ||||||
|  |         # downloading logic is adapted from clip-server's CLIPOnnxModel class | ||||||
|  |         download_model( | ||||||
|  |             url=_S3_BUCKET_V2 + model_name, | ||||||
|  |             target_folder=self.cache_dir.as_posix(), | ||||||
|  |             md5sum=model_md5, | ||||||
|  |             with_resume=True, | ||||||
|  |         ) | ||||||
|  |         file = self.cache_dir / model_name.split("/")[1] | ||||||
|  |         if file.suffix == ".zip": | ||||||
|  |             with zipfile.ZipFile(file, "r") as zip_ref: | ||||||
|  |                 zip_ref.extractall(self.cache_dir) | ||||||
|  |             os.remove(file) | ||||||
|  |         return True | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # same as `_transform_blob` without `_blob2image` | ||||||
|  | def _transform_pil_image(n_px: int) -> Compose: | ||||||
|  |     return Compose( | ||||||
|  |         [ | ||||||
|  |             Resize(n_px, interpolation=BICUBIC), | ||||||
|  |             CenterCrop(n_px), | ||||||
|  |             _convert_image_to_rgb, | ||||||
|  |             ToTensor(), | ||||||
|  |             Normalize( | ||||||
|  |                 (0.48145466, 0.4578275, 0.40821073), | ||||||
|  |                 (0.26862954, 0.26130258, 0.27577711), | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  |     ) | ||||||
|  | |||||||
| @ -4,6 +4,7 @@ from typing import Any | |||||||
| 
 | 
 | ||||||
| import cv2 | import cv2 | ||||||
| import numpy as np | import numpy as np | ||||||
|  | import onnxruntime as ort | ||||||
| from insightface.model_zoo import ArcFaceONNX, RetinaFace | from insightface.model_zoo import ArcFaceONNX, RetinaFace | ||||||
| from insightface.utils.face_align import norm_crop | from insightface.utils.face_align import norm_crop | ||||||
| from insightface.utils.storage import BASE_REPO_URL, download_file | from insightface.utils.storage import BASE_REPO_URL, download_file | ||||||
| @ -42,15 +43,31 @@ class FaceRecognizer(InferenceModel): | |||||||
|             rec_file = next(self.cache_dir.glob("w600k_*.onnx")) |             rec_file = next(self.cache_dir.glob("w600k_*.onnx")) | ||||||
|         except StopIteration: |         except StopIteration: | ||||||
|             raise FileNotFoundError("Facial recognition models not found in cache directory") |             raise FileNotFoundError("Facial recognition models not found in cache directory") | ||||||
|         self.det_model = RetinaFace(det_file.as_posix()) | 
 | ||||||
|         self.rec_model = ArcFaceONNX(rec_file.as_posix()) |         self.det_model = RetinaFace( | ||||||
|  |             session=ort.InferenceSession( | ||||||
|  |                 det_file.as_posix(), | ||||||
|  |                 sess_options=self.sess_options, | ||||||
|  |                 providers=self.providers, | ||||||
|  |                 provider_options=self.provider_options, | ||||||
|  |             ), | ||||||
|  |         ) | ||||||
|  |         self.rec_model = ArcFaceONNX( | ||||||
|  |             rec_file.as_posix(), | ||||||
|  |             session=ort.InferenceSession( | ||||||
|  |                 rec_file.as_posix(), | ||||||
|  |                 sess_options=self.sess_options, | ||||||
|  |                 providers=self.providers, | ||||||
|  |                 provider_options=self.provider_options, | ||||||
|  |             ), | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         self.det_model.prepare( |         self.det_model.prepare( | ||||||
|             ctx_id=-1, |             ctx_id=0, | ||||||
|             det_thresh=self.min_score, |             det_thresh=self.min_score, | ||||||
|             input_size=(640, 640), |             input_size=(640, 640), | ||||||
|         ) |         ) | ||||||
|         self.rec_model.prepare(ctx_id=-1) |         self.rec_model.prepare(ctx_id=0) | ||||||
| 
 | 
 | ||||||
|     def _predict(self, image: cv2.Mat) -> list[dict[str, Any]]: |     def _predict(self, image: cv2.Mat) -> list[dict[str, Any]]: | ||||||
|         bboxes, kpss = self.det_model.detect(image) |         bboxes, kpss = self.det_model.detect(image) | ||||||
|  | |||||||
| @ -2,8 +2,10 @@ from pathlib import Path | |||||||
| from typing import Any | from typing import Any | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import snapshot_download | from huggingface_hub import snapshot_download | ||||||
|  | from optimum.onnxruntime import ORTModelForImageClassification | ||||||
|  | from optimum.pipelines import pipeline | ||||||
| from PIL.Image import Image | from PIL.Image import Image | ||||||
| from transformers.pipelines import pipeline | from transformers import AutoImageProcessor | ||||||
| 
 | 
 | ||||||
| from ..config import settings | from ..config import settings | ||||||
| from ..schemas import ModelType | from ..schemas import ModelType | ||||||
| @ -25,15 +27,34 @@ class ImageClassifier(InferenceModel): | |||||||
| 
 | 
 | ||||||
|     def _download(self, **model_kwargs: Any) -> None: |     def _download(self, **model_kwargs: Any) -> None: | ||||||
|         snapshot_download( |         snapshot_download( | ||||||
|             cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"] |             cache_dir=self.cache_dir, | ||||||
|  |             repo_id=self.model_name, | ||||||
|  |             allow_patterns=["*.bin", "*.json", "*.txt"], | ||||||
|  |             local_dir=self.cache_dir, | ||||||
|  |             local_dir_use_symlinks=True, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def _load(self, **model_kwargs: Any) -> None: |     def _load(self, **model_kwargs: Any) -> None: | ||||||
|         self.model = pipeline( |         processor = AutoImageProcessor.from_pretrained(self.cache_dir) | ||||||
|             self.model_type.value, |         model_kwargs |= { | ||||||
|             self.model_name, |             "cache_dir": self.cache_dir, | ||||||
|             model_kwargs={"cache_dir": self.cache_dir, **model_kwargs}, |             "provider": self.providers[0], | ||||||
|         ) |             "provider_options": self.provider_options[0], | ||||||
|  |             "session_options": self.sess_options, | ||||||
|  |         } | ||||||
|  |         model_path = self.cache_dir / "model.onnx" | ||||||
|  | 
 | ||||||
|  |         if model_path.exists(): | ||||||
|  |             model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs) | ||||||
|  |             self.model = pipeline(self.model_type.value, model, feature_extractor=processor) | ||||||
|  |         else: | ||||||
|  |             self.sess_options.optimized_model_filepath = model_path.as_posix() | ||||||
|  |             self.model = pipeline( | ||||||
|  |                 self.model_type.value, | ||||||
|  |                 self.model_name, | ||||||
|  |                 model_kwargs=model_kwargs, | ||||||
|  |                 feature_extractor=processor, | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|     def _predict(self, image: Image) -> list[str]: |     def _predict(self, image: Image) -> list[str]: | ||||||
|         predictions: list[dict[str, Any]] = self.model(image)  # type: ignore |         predictions: list[dict[str, Any]] = self.model(image)  # type: ignore | ||||||
|  | |||||||
| @ -1,17 +1,20 @@ | |||||||
|  | import pickle | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
| from typing import TypeAlias | from typing import TypeAlias | ||||||
| from unittest import mock | from unittest import mock | ||||||
| 
 | 
 | ||||||
| import cv2 | import cv2 | ||||||
| import numpy as np | import numpy as np | ||||||
|  | import onnxruntime as ort | ||||||
| import pytest | import pytest | ||||||
| from fastapi.testclient import TestClient | from fastapi.testclient import TestClient | ||||||
| from PIL import Image | from PIL import Image | ||||||
| from pytest_mock import MockerFixture | from pytest_mock import MockerFixture | ||||||
| 
 | 
 | ||||||
| from .config import settings | from .config import settings | ||||||
|  | from .models.base import PicklableSessionOptions | ||||||
| from .models.cache import ModelCache | from .models.cache import ModelCache | ||||||
| from .models.clip import CLIPSTEncoder | from .models.clip import CLIPEncoder | ||||||
| from .models.facial_recognition import FaceRecognizer | from .models.facial_recognition import FaceRecognizer | ||||||
| from .models.image_classification import ImageClassifier | from .models.image_classification import ImageClassifier | ||||||
| from .schemas import ModelType | from .schemas import ModelType | ||||||
| @ -72,45 +75,47 @@ class TestCLIP: | |||||||
|     embedding = np.random.rand(512).astype(np.float32) |     embedding = np.random.rand(512).astype(np.float32) | ||||||
| 
 | 
 | ||||||
|     def test_eager_init(self, mocker: MockerFixture) -> None: |     def test_eager_init(self, mocker: MockerFixture) -> None: | ||||||
|         mocker.patch.object(CLIPSTEncoder, "download") |         mocker.patch.object(CLIPEncoder, "download") | ||||||
|         mock_load = mocker.patch.object(CLIPSTEncoder, "load") |         mock_load = mocker.patch.object(CLIPEncoder, "load") | ||||||
|         clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg") |         clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=True, test_arg="test_arg") | ||||||
| 
 | 
 | ||||||
|         assert clip_model.model_name == "test_model_name" |         assert clip_model.model_name == "ViT-B-32::openai" | ||||||
|         mock_load.assert_called_once_with(test_arg="test_arg") |         mock_load.assert_called_once_with(test_arg="test_arg") | ||||||
| 
 | 
 | ||||||
|     def test_lazy_init(self, mocker: MockerFixture) -> None: |     def test_lazy_init(self, mocker: MockerFixture) -> None: | ||||||
|         mock_download = mocker.patch.object(CLIPSTEncoder, "download") |         mock_download = mocker.patch.object(CLIPEncoder, "download") | ||||||
|         mock_load = mocker.patch.object(CLIPSTEncoder, "load") |         mock_load = mocker.patch.object(CLIPEncoder, "load") | ||||||
|         clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg") |         clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=False, test_arg="test_arg") | ||||||
| 
 | 
 | ||||||
|         assert clip_model.model_name == "test_model_name" |         assert clip_model.model_name == "ViT-B-32::openai" | ||||||
|         mock_download.assert_called_once_with(test_arg="test_arg") |         mock_download.assert_called_once_with(test_arg="test_arg") | ||||||
|         mock_load.assert_not_called() |         mock_load.assert_not_called() | ||||||
| 
 | 
 | ||||||
|     def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None: |     def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None: | ||||||
|         mocker.patch.object(CLIPSTEncoder, "load") |         mocker.patch.object(CLIPEncoder, "download") | ||||||
|         clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache") |         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True) | ||||||
|         clip_encoder.model = mock.Mock() |         mocked.return_value.run.return_value = [[self.embedding]] | ||||||
|         clip_encoder.model.encode.return_value = self.embedding |         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision") | ||||||
|  |         assert clip_encoder.mode == "vision" | ||||||
|         embedding = clip_encoder.predict(pil_image) |         embedding = clip_encoder.predict(pil_image) | ||||||
| 
 | 
 | ||||||
|         assert isinstance(embedding, list) |         assert isinstance(embedding, list) | ||||||
|         assert len(embedding) == 512 |         assert len(embedding) == 512 | ||||||
|         assert all([isinstance(num, float) for num in embedding]) |         assert all([isinstance(num, float) for num in embedding]) | ||||||
|         clip_encoder.model.encode.assert_called_once() |         clip_encoder.vision_model.run.assert_called_once() | ||||||
| 
 | 
 | ||||||
|     def test_basic_text(self, mocker: MockerFixture) -> None: |     def test_basic_text(self, mocker: MockerFixture) -> None: | ||||||
|         mocker.patch.object(CLIPSTEncoder, "load") |         mocker.patch.object(CLIPEncoder, "download") | ||||||
|         clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache") |         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True) | ||||||
|         clip_encoder.model = mock.Mock() |         mocked.return_value.run.return_value = [[self.embedding]] | ||||||
|         clip_encoder.model.encode.return_value = self.embedding |         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text") | ||||||
|  |         assert clip_encoder.mode == "text" | ||||||
|         embedding = clip_encoder.predict("test search query") |         embedding = clip_encoder.predict("test search query") | ||||||
| 
 | 
 | ||||||
|         assert isinstance(embedding, list) |         assert isinstance(embedding, list) | ||||||
|         assert len(embedding) == 512 |         assert len(embedding) == 512 | ||||||
|         assert all([isinstance(num, float) for num in embedding]) |         assert all([isinstance(num, float) for num in embedding]) | ||||||
|         clip_encoder.model.encode.assert_called_once() |         clip_encoder.text_model.run.assert_called_once() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TestFaceRecognition: | class TestFaceRecognition: | ||||||
| @ -254,3 +259,13 @@ class TestEndpoints: | |||||||
|             headers=headers, |             headers=headers, | ||||||
|         ) |         ) | ||||||
|         assert response.status_code == 200 |         assert response.status_code == 200 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_sess_options() -> None: | ||||||
|  |     sess_options = PicklableSessionOptions() | ||||||
|  |     sess_options.intra_op_num_threads = 1 | ||||||
|  |     sess_options.inter_op_num_threads = 1 | ||||||
|  |     pickled = pickle.dumps(sess_options) | ||||||
|  |     unpickled = pickle.loads(pickled) | ||||||
|  |     assert unpickled.intra_op_num_threads == 1 | ||||||
|  |     assert unpickled.inter_op_num_threads == 1 | ||||||
|  | |||||||
							
								
								
									
										1739
									
								
								machine-learning/poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1739
									
								
								machine-learning/poetry.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -13,7 +13,6 @@ torch = [ | |||||||
|     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.0.1", source = "pytorch-cpu"} |     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.0.1", source = "pytorch-cpu"} | ||||||
| ] | ] | ||||||
| transformers = "^4.29.2" | transformers = "^4.29.2" | ||||||
| sentence-transformers = "^2.2.2" |  | ||||||
| onnxruntime = "^1.15.0" | onnxruntime = "^1.15.0" | ||||||
| insightface = "^0.7.3" | insightface = "^0.7.3" | ||||||
| opencv-python-headless = "^4.7.0.72" | opencv-python-headless = "^4.7.0.72" | ||||||
| @ -22,6 +21,15 @@ fastapi = "^0.95.2" | |||||||
| uvicorn = {extras = ["standard"], version = "^0.22.0"} | uvicorn = {extras = ["standard"], version = "^0.22.0"} | ||||||
| pydantic = "^1.10.8" | pydantic = "^1.10.8" | ||||||
| aiocache = "^0.12.1" | aiocache = "^0.12.1" | ||||||
|  | optimum = "^1.9.1" | ||||||
|  | torchvision = [ | ||||||
|  |     {markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=0.15.2", source = "pypi"}, | ||||||
|  |     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=0.15.2", source = "pytorch-cpu"} | ||||||
|  | ] | ||||||
|  | rich = "^13.4.2" | ||||||
|  | ftfy = "^6.1.1" | ||||||
|  | setuptools = "^68.0.0" | ||||||
|  | open-clip-torch = "^2.20.0" | ||||||
| 
 | 
 | ||||||
| [tool.poetry.group.dev.dependencies] | [tool.poetry.group.dev.dependencies] | ||||||
| mypy = "^1.3.0" | mypy = "^1.3.0" | ||||||
| @ -62,13 +70,20 @@ warn_untyped_fields = true | |||||||
| [[tool.mypy.overrides]] | [[tool.mypy.overrides]] | ||||||
| module = [ | module = [ | ||||||
|     "huggingface_hub", |     "huggingface_hub", | ||||||
|     "transformers.pipelines", |     "transformers", | ||||||
|     "cv2", |     "cv2", | ||||||
|     "insightface.model_zoo", |     "insightface.model_zoo", | ||||||
|     "insightface.utils.face_align", |     "insightface.utils.face_align", | ||||||
|     "insightface.utils.storage", |     "insightface.utils.storage", | ||||||
|     "sentence_transformers", |     "onnxruntime", | ||||||
|     "sentence_transformers.util", |     "optimum", | ||||||
|  |     "optimum.pipelines", | ||||||
|  |     "optimum.onnxruntime", | ||||||
|  |     "clip_server.model.clip", | ||||||
|  |     "clip_server.model.clip_onnx", | ||||||
|  |     "clip_server.model.pretrained_models", | ||||||
|  |     "clip_server.model.tokenization", | ||||||
|  |     "torchvision.transforms", | ||||||
|     "aiocache.backends.memory", |     "aiocache.backends.memory", | ||||||
|     "aiocache.lock", |     "aiocache.lock", | ||||||
|     "aiocache.plugins" |     "aiocache.plugins" | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								machine-learning/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								machine-learning/requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | |||||||
|  | # requirements to be installed with `--no-deps` flag | ||||||
|  | clip-server==0.8.* | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user