mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-31 02:27:08 -04:00 
			
		
		
		
	
		
			
				
	
	
		
			257 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			257 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import asyncio
 | |
| import gc
 | |
| import os
 | |
| import signal
 | |
| import threading
 | |
| import time
 | |
| from concurrent.futures import ThreadPoolExecutor
 | |
| from contextlib import asynccontextmanager
 | |
| from functools import partial
 | |
| from typing import Any, AsyncGenerator, Callable, Iterator
 | |
| from zipfile import BadZipFile
 | |
| 
 | |
| import orjson
 | |
| from fastapi import Depends, FastAPI, File, Form, HTTPException
 | |
| from fastapi.responses import ORJSONResponse, PlainTextResponse
 | |
| from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
 | |
| from PIL.Image import Image
 | |
| from pydantic import ValidationError
 | |
| from starlette.formparsers import MultiPartParser
 | |
| 
 | |
| from immich_ml.models import get_model_deps
 | |
| from immich_ml.models.base import InferenceModel
 | |
| from immich_ml.models.transforms import decode_pil
 | |
| 
 | |
| from .config import PreloadModelData, log, settings
 | |
| from .models.cache import ModelCache
 | |
| from .schemas import (
 | |
|     InferenceEntries,
 | |
|     InferenceEntry,
 | |
|     InferenceResponse,
 | |
|     ModelFormat,
 | |
|     ModelIdentity,
 | |
|     ModelTask,
 | |
|     ModelType,
 | |
|     PipelineRequest,
 | |
|     T,
 | |
| )
 | |
| 
 | |
# Multipart uploads are spooled to disk once the payload is 64 MiB or larger.
MultiPartParser.max_file_size = 64 * 1024 * 1024

# Shared model cache; revalidation is switched on only when a positive TTL is configured.
model_cache = ModelCache(revalidate=settings.model_ttl > 0)
# Executor used by `run` for blocking work; stays None until `lifespan` creates it.
thread_pool: ThreadPoolExecutor | None = None
# Held while a model is being loaded; `idle_shutdown_task` also checks it before shutting down.
lock = threading.Lock()
# Number of /predict requests currently in flight (maintained by `update_state`).
active_requests: int = 0
# `time.time()` stamp of the most recent /predict request, or None before the first one.
last_called: float | None = None
 | |
| 
 | |
| 
 | |
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
    """Set up and tear down process-wide resources around the app's lifetime.

    Startup:
      * creates the request thread pool when ``settings.request_threads`` is positive;
      * schedules ``idle_shutdown_task`` when both ``settings.model_ttl`` and
        ``settings.model_ttl_poll_s`` are positive;
      * preloads the models listed in ``settings.preload``, if any.

    Shutdown (always runs, even if startup fails part-way):
      * clears the log handlers, cancels the idle watcher, drops the cached
        models, shuts the thread pool down and forces a garbage collection.
    """
    global thread_pool
    idle_task: asyncio.Future[None] | None = None

    unloading = f"after {settings.model_ttl}s of inactivity" if settings.model_ttl > 0 else "disabled"
    log.info(f"Created in-memory cache with unloading {unloading}.")

    try:
        if settings.request_threads > 0:
            # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
            thread_pool = ThreadPoolExecutor(settings.request_threads)
            log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
        if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
            # Keep a strong reference: the event loop only holds weak references to
            # tasks, so an unreferenced watcher could be garbage-collected mid-run.
            idle_task = asyncio.ensure_future(idle_shutdown_task())
        if settings.preload is not None:
            await preload_models(settings.preload)
        yield
    finally:
        log.handlers.clear()
        if idle_task is not None:
            # No-op if the watcher already finished; otherwise stops it polling during shutdown.
            idle_task.cancel()
        # Drop the cache's references so the models can actually be collected below.
        # (`del model` inside a loop only unbinds the loop variable and frees nothing.)
        # NOTE(review): assumes `_cache` is a mutable mapping such as a dict - confirm.
        model_cache.cache._cache.clear()
        if thread_pool is not None:
            thread_pool.shutdown()
        gc.collect()
 | |
| 
 | |
| 
 | |
async def preload_models(preload: PreloadModelData) -> None:
    """Fetch and load every model named in the preload configuration.

    Each configured field holds a comma-separated list of model names. Models are
    loaded one at a time in this order: CLIP textual, CLIP visual, face detection,
    face recognition. Afterwards a deprecation warning is logged for each legacy
    fallback setting that is still populated.
    """
    log.info(f"Preloading models: clip:{preload.clip} facial_recognition:{preload.facial_recognition}")

    async def _warm_up(names: str | None, model_type: ModelType, model_task: ModelTask) -> None:
        # Unset fields are simply skipped.
        if names is None:
            return
        for raw_name in names.split(","):
            cached = await model_cache.get(raw_name.strip(), model_type, model_task)
            await load(cached)

    await _warm_up(preload.clip.textual, ModelType.TEXTUAL, ModelTask.SEARCH)
    await _warm_up(preload.clip.visual, ModelType.VISUAL, ModelTask.SEARCH)
    await _warm_up(preload.facial_recognition.detection, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
    await _warm_up(preload.facial_recognition.recognition, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)

    if preload.clip_fallback is not None:
        log.warning(
            "Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__CLIP'. "
            "Use 'MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL' and "
            "'MACHINE_LEARNING_PRELOAD__CLIP__VISUAL' instead."
        )

    if preload.facial_recognition_fallback is not None:
        log.warning(
            "Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION'. "
            "Use 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION' and "
            "'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION' instead."
        )
 | |
| 
 | |
| 
 | |
def update_state() -> Iterator[None]:
    """Generator dependency that tracks request activity for idle shutdown.

    Before the request is handled it bumps the in-flight counter and stamps the
    current time; once the request finishes (successfully or not) the counter is
    decremented again.
    """
    global active_requests, last_called
    active_requests = active_requests + 1
    last_called = time.time()
    try:
        # Control returns here after the request has been processed.
        yield
    finally:
        active_requests = active_requests - 1
 | |
| 
 | |
| 
 | |
def get_entries(entries: str = Form()) -> InferenceEntries:
    """Parse the ``entries`` form field into inference entries grouped by dependency.

    ``entries`` is a JSON object shaped ``{task: {type: {"modelName": ..., "options": ...}}}``;
    ``options`` is optional and defaults to an empty dict.

    Returns:
        A ``(without_deps, with_deps)`` pair of entry lists, split on whether
        ``get_model_deps`` reports any dependencies for the entry.

    Raises:
        HTTPException: 422 when the payload is not valid JSON or does not have
            the expected structure.
    """
    try:
        request: PipelineRequest = orjson.loads(entries)
        without_deps: list[InferenceEntry] = []
        with_deps: list[InferenceEntry] = []
        for task, types in request.items():
            for model_type, entry in types.items():
                parsed: InferenceEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": model_type,
                    "options": entry.get("options", {}),
                }
                dep = get_model_deps(parsed["name"], model_type, task)
                (with_deps if dep else without_deps).append(parsed)
        return without_deps, with_deps
    # TypeError covers entries that are not JSON objects (e.g. a string, list or
    # null), which fail on `entry["modelName"]`; without it they surface as a 500.
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError, TypeError) as e:
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.") from e
 | |
| 
 | |
| 
 | |
# ASGI application; `lifespan` handles startup and shutdown of shared resources.
app = FastAPI(lifespan=lifespan)
 | |
| 
 | |
| 
 | |
@app.get("/")
async def root() -> ORJSONResponse:
    """Return a small JSON banner identifying the service."""
    banner = {"message": "Immich ML"}
    return ORJSONResponse(banner)
 | |
| 
 | |
| 
 | |
@app.get("/ping")
def ping() -> PlainTextResponse:
    """Liveness probe: always answers with the plain-text body ``pong``."""
    reply = "pong"
    return PlainTextResponse(reply)
 | |
| 
 | |
| 
 | |
@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
    entries: InferenceEntries = Depends(get_entries),
    image: bytes | None = File(default=None),
    text: str | None = Form(default=None),
) -> Any:
    """Run the requested inference pipeline on an uploaded image or a text string.

    When both ``image`` and ``text`` are supplied, the image takes precedence.

    Raises:
        HTTPException: 400 when neither ``image`` nor ``text`` is provided.
    """
    if image is None and text is None:
        raise HTTPException(400, "Either image or text must be provided")

    inputs: Image | str
    if image is not None:
        # Decoding is blocking work, so it goes through `run`.
        inputs = await run(decode_pil, image)
    else:
        inputs = text

    result = await run_inference(inputs, entries)
    return ORJSONResponse(result)
 | |
| 
 | |
| 
 | |
async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
    """Execute every requested model on ``payload`` and collect the results by task.

    Entries without dependencies are run concurrently first; entries with
    dependencies run afterwards and receive the outputs of the models they depend
    on as extra positional inputs. For image payloads the response also carries
    ``imageHeight`` and ``imageWidth``.

    Raises:
        HTTPException: 400 when an entry depends on an output that was not produced.
    """
    outputs: dict[ModelIdentity, Any] = {}
    response: InferenceResponse = {}

    async def _run_inference(entry: InferenceEntry) -> None:
        model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)

        # Gather outputs of prerequisite models before doing any loading work.
        dep_outputs: list[Any] = []
        for dep in model.depends:
            if dep not in outputs:
                message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
                raise HTTPException(400, message)
            dep_outputs.append(outputs[dep])

        model = await load(model)
        output = await run(model.predict, payload, *dep_outputs, **entry["options"])
        outputs[model.identity] = output
        response[entry["task"]] = output

    independent, dependent = entries
    await asyncio.gather(*(_run_inference(item) for item in independent))
    if dependent:
        await asyncio.gather(*(_run_inference(item) for item in dependent))

    if isinstance(payload, Image):
        response["imageHeight"] = payload.height
        response["imageWidth"] = payload.width

    return response
 | |
| 
 | |
| 
 | |
async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    """Call ``func(*args, **kwargs)``, off-loading to the thread pool when one exists.

    Without a thread pool the function is called directly, blocking the event
    loop until it returns.
    """
    if thread_pool is not None:
        loop = asyncio.get_running_loop()
        bound_call = partial(func, *args, **kwargs)
        return await loop.run_in_executor(thread_pool, bound_call)
    return func(*args, **kwargs)
 | |
| 
 | |
| 
 | |
async def load(model: InferenceModel) -> InferenceModel:
    """Ensure ``model`` is loaded and return it.

    Already-loaded models are returned immediately. Otherwise the model is loaded
    via `run` while holding the module-level lock. If loading fails with an
    ``OSError``, ``InvalidProtobuf``, ``BadZipFile`` or ``NoSuchFile``, the model's
    cache is cleared and the load is attempted once more.

    Raises:
        HTTPException: 500 when ``model.load_attempts`` is already greater than 1.
    """
    if model.loaded:
        return model

    def _load_blocking(target: InferenceModel) -> InferenceModel:
        if target.load_attempts > 1:
            raise HTTPException(500, f"Failed to load model '{target.model_name}'")
        with lock:
            try:
                target.load()
            except FileNotFoundError as e:
                # A missing file for the ONNX format has no further fallback here.
                if target.model_format == ModelFormat.ONNX:
                    raise
                log.warning(
                    f"{target.model_format.upper()} is available, but model '{target.model_name}' does not support it.",
                    exc_info=e,
                )
                # Fall back to ONNX and try again.
                target.model_format = ModelFormat.ONNX
                target.load()
        return target

    try:
        return await run(_load_blocking, model)
    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
        log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.")
        model.clear_cache()
        return await run(_load_blocking, model)
 | |
| 
 | |
| 
 | |
async def idle_shutdown_task() -> None:
    """Poll for inactivity and send SIGINT to this process once it has been idle too long.

    The process counts as idle when at least one request has been seen, none is
    in flight, the model-loading lock is free, and more than
    ``settings.model_ttl`` seconds have passed since the last request. The check
    repeats every ``settings.model_ttl_poll_s`` seconds.
    """
    while True:
        quiescent = last_called is not None and not active_requests and not lock.locked()
        if quiescent and time.time() - last_called > settings.model_ttl:
            log.info("Shutting down due to inactivity.")
            os.kill(os.getpid(), signal.SIGINT)
            return
        await asyncio.sleep(settings.model_ttl_poll_s)
 |