Mirror of https://github.com/immich-app/immich.git
* modularize model classes
* various fixes
* expose port
* change response
* round coordinates
* simplify preload
* update server
* simplify interface
* update tests
* composable endpoint
* cleanup, fixes, remove unnecessary interface, support text input
* ew camelcase
* update server, server fixes, fix typing
* ml fixes, update locustfile
* cleaner response
* better repo response
* update tests, formatting and typing, rename
* undo compose change
* linting, fix type, actually fix typing
* stricter typing, fix detection-only response, no need for defaultdict
* update spec file, update api linting
* update e2e
* unnecessary dimension
* remove commented code
* remove duplicate code
* remove unused imports
* add batch dim
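The "composable endpoint" change refers to the `/predict` route in the file below: a single multipart form request carries an `entries` JSON field describing which models to run, plus either an `image` file or a `text` field. Based on how `get_entries` parses that field, a payload maps task → model type → model entry. A minimal sketch follows; the task/type strings, model names, and option keys are illustrative placeholders, not confirmed values from `app.schemas`:

# Hypothetical `entries` payload as parsed by get_entries() in the file below.
# Keys and values are placeholders; real values come from ModelTask/ModelType in app.schemas.
entries = {
    "facial-recognition": {
        "detection": {"modelName": "buffalo_l", "options": {"minScore": 0.7}},
        "recognition": {"modelName": "buffalo_l"},
    },
    "search": {
        "visual": {"modelName": "ViT-B-32__openai", "options": {}},
    },
}

Entries whose model declares a dependency (via `get_model_deps`) are deferred to a second pass so they can consume outputs from the first pass, for instance a recognition model consuming a detection model's output.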
		
			
				
	
	
		
import asyncio
import gc
import os
import signal
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from functools import partial
from typing import Any, AsyncGenerator, Callable, Iterator
from zipfile import BadZipFile

import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
from pydantic import ValidationError
from starlette.formparsers import MultiPartParser

from app.models import get_model_deps
from app.models.base import InferenceModel
from app.models.transforms import decode_pil

from .config import PreloadModelData, log, settings
from .models.cache import ModelCache
from .schemas import (
    InferenceEntries,
    InferenceEntry,
    InferenceResponse,
    MessageResponse,
    ModelIdentity,
    ModelTask,
    ModelType,
    PipelineRequest,
    T,
    TextResponse,
)

MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger

model_cache = ModelCache(revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0
last_called: float | None = None


@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
    global thread_pool
    log.info(
        (
            "Created in-memory cache with unloading "
            f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
        )
    )

    try:
        if settings.request_threads > 0:
            # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
            thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
            log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
        if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
            asyncio.ensure_future(idle_shutdown_task())
        if settings.preload is not None:
            await preload_models(settings.preload)
        yield
    finally:
        log.handlers.clear()
        for model in model_cache.cache._cache.values():
            del model
        if thread_pool is not None:
            thread_pool.shutdown()
        gc.collect()


async def preload_models(preload: PreloadModelData) -> None:
    log.info(f"Preloading models: {preload}")
    if preload.clip is not None:
        model = await model_cache.get(preload.clip, ModelType.TEXTUAL, ModelTask.SEARCH)
        await load(model)

        model = await model_cache.get(preload.clip, ModelType.VISUAL, ModelTask.SEARCH)
        await load(model)

    if preload.facial_recognition is not None:
        model = await model_cache.get(preload.facial_recognition, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
        await load(model)

        model = await model_cache.get(preload.facial_recognition, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
        await load(model)


def update_state() -> Iterator[None]:
    # Tracks in-flight requests and the time of the last request for idle shutdown.
    global active_requests, last_called
    active_requests += 1
    last_called = time.time()
    try:
        yield
    finally:
        active_requests -= 1


def get_entries(entries: str = Form()) -> InferenceEntries:
    # `entries` is a JSON form field: task -> model type -> {"modelName": ..., "options": ...}.
    # Entries are split by whether their model depends on another entry's output.
    try:
        request: PipelineRequest = orjson.loads(entries)
        without_deps: list[InferenceEntry] = []
        with_deps: list[InferenceEntry] = []
        for task, types in request.items():
            for type, entry in types.items():
                parsed: InferenceEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": type,
                    "options": entry.get("options", {}),
                }
                dep = get_model_deps(parsed["name"], type, task)
                (with_deps if dep else without_deps).append(parsed)
        return without_deps, with_deps
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.")


app = FastAPI(lifespan=lifespan)


@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
    return {"message": "Immich ML"}


@app.get("/ping", response_model=TextResponse)
def ping() -> str:
    return "pong"


@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
    entries: InferenceEntries = Depends(get_entries),
    image: bytes | None = File(default=None),
    text: str | None = Form(default=None),
) -> Any:
    if image is not None:
        inputs: Image | str = await run(lambda: decode_pil(image))
    elif text is not None:
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")
    response = await run_inference(inputs, entries)
    return ORJSONResponse(response)


async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
    outputs: dict[ModelIdentity, Any] = {}
    response: InferenceResponse = {}

    async def _run_inference(entry: InferenceEntry) -> None:
        model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
        inputs = [payload]
        for dep in model.depends:
            try:
                inputs.append(outputs[dep])
            except KeyError:
                message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
                raise HTTPException(400, message)
        model = await load(model)
        output = await run(model.predict, *inputs, **entry["options"])
        outputs[model.identity] = output
        response[entry["task"]] = output

    # Entries without dependencies run first so their outputs are available to dependent entries.
    without_deps, with_deps = entries
    await asyncio.gather(*[_run_inference(entry) for entry in without_deps])
    if with_deps:
        await asyncio.gather(*[_run_inference(entry) for entry in with_deps])
    if isinstance(payload, Image):
        response["imageHeight"], response["imageWidth"] = payload.height, payload.width

    return response


async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    if thread_pool is None:
        return func(*args, **kwargs)
    partial_func = partial(func, *args, **kwargs)
    return await asyncio.get_running_loop().run_in_executor(thread_pool, partial_func)


async def load(model: InferenceModel) -> InferenceModel:
    if model.loaded:
        return model

    def _load(model: InferenceModel) -> InferenceModel:
        # Serialize loading behind a lock so only one model loads at a time.
        with lock:
            model.load()
        return model

    try:
        await run(_load, model)
        return model
    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
        log.warning(
            (
                f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. "
                "Clearing cache and retrying."
            )
        )
        model.clear_cache()
        await run(_load, model)
        return model


async def idle_shutdown_task() -> None:
    # Periodically checks for inactivity and sends SIGINT so the process can exit and free model memory.
    while True:
        log.debug("Checking for inactivity...")
        if (
            last_called is not None
            and not active_requests
            and not lock.locked()
            and time.time() - last_called > settings.model_ttl
        ):
            log.info("Shutting down due to inactivity.")
            os.kill(os.getpid(), signal.SIGINT)
            break
        await asyncio.sleep(settings.model_ttl_poll_s)
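For completeness, a rough client-side sketch of how the multipart contract above might be exercised with httpx. The base URL and port, the task/type strings, and the model names are assumptions for illustration; within Immich, the server talks to this service through its own machine-learning client.

# Illustrative client only: URL, port, and the entry keys below are assumptions.
import json

import httpx

ML_URL = "http://localhost:3003/predict"  # assumed default port

# Text input: `text` is sent as a plain form field alongside `entries`.
text_entries = {"search": {"textual": {"modelName": "ViT-B-32__openai"}}}
r = httpx.post(ML_URL, data={"entries": json.dumps(text_entries), "text": "a dog on a beach"})
r.raise_for_status()
print(r.json())  # response is keyed by task, e.g. an embedding under "search"

# Image input: raw bytes go in the `image` file field; image responses also
# include "imageHeight" and "imageWidth".
face_entries = {
    "facial-recognition": {
        "detection": {"modelName": "buffalo_l"},
        "recognition": {"modelName": "buffalo_l"},
    }
}
with open("photo.jpg", "rb") as f:
    r = httpx.post(ML_URL, data={"entries": json.dumps(face_entries)}, files={"image": f})
print(r.json())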