mirror of https://github.com/immich-app/immich.git · synced 2025-11-03 11:07:10 -05:00
257 lines · 8.9 KiB · Python

import asyncio
import gc
import os
import signal
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from functools import partial
from typing import Any, AsyncGenerator, Callable, Iterator
from zipfile import BadZipFile

import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi.responses import ORJSONResponse, PlainTextResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
from pydantic import ValidationError
from starlette.formparsers import MultiPartParser

from immich_ml.models import get_model_deps
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_pil

from .config import PreloadModelData, log, settings
from .models.cache import ModelCache
from .schemas import (
    InferenceEntries,
    InferenceEntry,
    InferenceResponse,
    ModelFormat,
    ModelIdentity,
    ModelTask,
    ModelType,
    PipelineRequest,
    T,
)
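
# FastAPI entrypoint for the Immich machine-learning service: exposes /predict
# for CLIP and facial-recognition inference, keeps loaded models in an
# in-memory cache, and can shut itself down after a configurable idle period.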

MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger

model_cache = ModelCache(revalidate=settings.model_ttl > 0)  # shared in-memory cache of loaded models
thread_pool: ThreadPoolExecutor | None = None  # created in lifespan() when settings.request_threads > 0
lock = threading.Lock()  # serializes model loading
active_requests = 0  # number of in-flight requests, tracked by update_state()
last_called: float | None = None  # time.time() of the most recent request, read by idle_shutdown_task()


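# Application lifespan: build the request thread pool, start the idle-shutdown
# watchdog, and preload any configured models on startup; on shutdown, release
# cached models and the pool.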
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
    global thread_pool
    log.info(
        (
            "Created in-memory cache with unloading "
            f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
        )
    )

    try:
        if settings.request_threads > 0:
            # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
            thread_pool = ThreadPoolExecutor(settings.request_threads)
            log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
        if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
            asyncio.ensure_future(idle_shutdown_task())
        if settings.preload is not None:
            await preload_models(settings.preload)
        yield
    finally:
        log.handlers.clear()
        for model in model_cache.cache._cache.values():
            del model
        if thread_pool is not None:
            thread_pool.shutdown()
        gc.collect()


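# Warm the cache at startup so the first real request doesn't pay model
# download and initialization costs.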
async def preload_models(preload: PreloadModelData) -> None:
    log.info(f"Preloading models: clip:{preload.clip} facial_recognition:{preload.facial_recognition}")

    async def load_models(model_string: str, model_type: ModelType, model_task: ModelTask) -> None:
        for model_name in model_string.split(","):
            model_name = model_name.strip()
            model = await model_cache.get(model_name, model_type, model_task)
            await load(model)

    if preload.clip.textual is not None:
        await load_models(preload.clip.textual, ModelType.TEXTUAL, ModelTask.SEARCH)

    if preload.clip.visual is not None:
        await load_models(preload.clip.visual, ModelType.VISUAL, ModelTask.SEARCH)

    if preload.facial_recognition.detection is not None:
        await load_models(
            preload.facial_recognition.detection,
            ModelType.DETECTION,
            ModelTask.FACIAL_RECOGNITION,
        )

    if preload.facial_recognition.recognition is not None:
        await load_models(
            preload.facial_recognition.recognition,
            ModelType.RECOGNITION,
            ModelTask.FACIAL_RECOGNITION,
        )

    if preload.clip_fallback is not None:
        log.warning(
            "Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__CLIP'. "
            "Use 'MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL' and "
            "'MACHINE_LEARNING_PRELOAD__CLIP__VISUAL' instead."
        )

    if preload.facial_recognition_fallback is not None:
        log.warning(
            "Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION'. "
            "Use 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION' and "
            "'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION' instead."
        )


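# Generator dependency attached to /predict: counts in-flight requests and
# records the wall-clock time of the last call for the idle watchdog.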
def update_state() -> Iterator[None]:
    global active_requests, last_called
    active_requests += 1
    last_called = time.time()
    try:
        yield
    finally:
        active_requests -= 1


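# The multipart "entries" form field carries JSON mapping task -> type -> entry.
# Illustrative shape (field names follow the parsing below; the exact task/type
# strings are defined by ModelTask and ModelType in immich_ml.schemas):
#   {"clip": {"visual": {"modelName": "ViT-B-32__openai", "options": {}}}}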
def get_entries(entries: str = Form()) -> InferenceEntries:
    try:
        request: PipelineRequest = orjson.loads(entries)
        without_deps: list[InferenceEntry] = []
        with_deps: list[InferenceEntry] = []
        for task, types in request.items():
            for type, entry in types.items():
                parsed: InferenceEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": type,
                    "options": entry.get("options", {}),
                }
                dep = get_model_deps(parsed["name"], type, task)
                (with_deps if dep else without_deps).append(parsed)
        return without_deps, with_deps
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.")


app = FastAPI(lifespan=lifespan)


@app.get("/")
async def root() -> ORJSONResponse:
    return ORJSONResponse({"message": "Immich ML"})


@app.get("/ping")
def ping() -> PlainTextResponse:
    return PlainTextResponse("pong")


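# Sketch of a client call (hypothetical; assumes the service's default port
# 3003 and an illustrative model name, so adjust to your deployment):
#   curl http://localhost:3003/predict \
#     -F 'entries={"clip": {"visual": {"modelName": "ViT-B-32__openai"}}}' \
#     -F image=@photo.jpg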
@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
    entries: InferenceEntries = Depends(get_entries),
    image: bytes | None = File(default=None),
    text: str | None = Form(default=None),
) -> Any:
    if image is not None:
        inputs: Image | str = await run(lambda: decode_pil(image))
    elif text is not None:
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")
    response = await run_inference(inputs, entries)
    return ORJSONResponse(response)


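# Runs dependency-free entries concurrently first, then entries whose models
# consume another model's output (model.depends), so those outputs are already
# available in `outputs` by the second wave.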
async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
    outputs: dict[ModelIdentity, Any] = {}
    response: InferenceResponse = {}

    async def _run_inference(entry: InferenceEntry) -> None:
        model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
        inputs = [payload]
        for dep in model.depends:
            try:
                inputs.append(outputs[dep])
            except KeyError:
                message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
                raise HTTPException(400, message)
        model = await load(model)
        output = await run(model.predict, *inputs, **entry["options"])
        outputs[model.identity] = output
        response[entry["task"]] = output

    without_deps, with_deps = entries
    await asyncio.gather(*[_run_inference(entry) for entry in without_deps])
    if with_deps:
        await asyncio.gather(*[_run_inference(entry) for entry in with_deps])
    if isinstance(payload, Image):
        response["imageHeight"], response["imageWidth"] = payload.height, payload.width

    return response


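# Run a blocking callable in the request thread pool (keeping the event loop
# responsive), or inline when no pool is configured.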
async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    if thread_pool is None:
        return func(*args, **kwargs)
    partial_func = partial(func, *args, **kwargs)
    return await asyncio.get_running_loop().run_in_executor(thread_pool, partial_func)


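# Model loading is serialized behind `lock`. If a non-ONNX format is configured
# but its files are missing, fall back to ONNX; if the files are corrupt or
# incomplete (OSError, InvalidProtobuf, BadZipFile, NoSuchFile), clear the
# cache and retry once, with load_attempts guarding against a retry loop.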
async def load(model: InferenceModel) -> InferenceModel:
    if model.loaded:
        return model

    def _load(model: InferenceModel) -> InferenceModel:
        if model.load_attempts > 1:
            raise HTTPException(500, f"Failed to load model '{model.model_name}'")
        with lock:
            try:
                model.load()
            except FileNotFoundError as e:
                if model.model_format == ModelFormat.ONNX:
                    raise e
                log.warning(
                    f"{model.model_format.upper()} is available, but model '{model.model_name}' does not support it.",
                    exc_info=e,
                )
                model.model_format = ModelFormat.ONNX
                model.load()
        return model

    try:
        return await run(_load, model)
    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
        log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.")
        model.clear_cache()
        return await run(_load, model)


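# Watchdog coroutine scheduled from lifespan(): once the server has been idle
# for longer than settings.model_ttl (and no model load is in progress), send
# SIGINT to this process so it exits cleanly and releases model memory.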
async def idle_shutdown_task() -> None:
    while True:
        if (
            last_called is not None
            and not active_requests
            and not lock.locked()
            and time.time() - last_called > settings.model_ttl
        ):
            log.info("Shutting down due to inactivity.")
            os.kill(os.getpid(), signal.SIGINT)
            break
        await asyncio.sleep(settings.model_ttl_poll_s)