mirror of https://github.com/immich-app/immich.git, synced 2025-11-01 11:07:11 -04:00
feat: add OCR functionality and related configurations

* chore: update labeler configuration for machine learning files
* feat(i18n): enhance OCR model descriptions and add orientation classification and unwarping features
* chore: update Dockerfile to include ccache for improved build performance
* feat(ocr): enhance OCR model configuration with orientation classification and unwarping options, update PaddleOCR integration, and improve response structure
* refactor(ocr): remove OCR_CLEANUP job from enum and type definitions
* refactor(ocr): remove obsolete OCR entity and migration files, and update asset job status and schema to accommodate new OCR table structure
* refactor(ocr): update OCR schema and response structure to use individual coordinates instead of bounding box, and adjust related service and repository files
* feat: enhance OCR configuration and functionality
  - Updated OCR settings to include minimum detection box score, minimum detection score, and minimum recognition score.
  - Refactored PaddleOCRecognizer to utilize new scoring parameters.
  - Introduced new database tables for asset OCR data and search functionality.
  - Modified related services and repositories to support the new OCR features.
  - Updated translations for improved clarity in settings UI.
* sql changes
* use rapidocr
* change dto
* update web
* update lock
* update api
* store positions as normalized floats
* match column order in db
* update admin ui settings descriptions, fix max resolution key, set min threshold to 0.1, fix bind
* apply config correctly, adjust defaults
* unnecessary model type
* unnecessary sources
* fix(ocr): switch RapidOCR lang type from LangDet to LangRec
* fix(ocr): expose lang_type (LangRec.CH) and font_path on OcrOptions for RapidOCR
* fix(ocr): make OCR text search case- and accent-insensitive using ILIKE + unaccent
* fix(ocr): add OCR search fields
* fix: Add OCR database migration and update ML prediction logic.
* trigrams are already case insensitive
* add tests
* format
* update migrations
* wrong uuid function
* linting
* maybe fix medium tests
* formatting
* fix weblate check
* openapi
* sql
* minor fixes
* maybe fix medium tests part 2
* passing medium tests
* format web
* readd sql
* format dart
* disabled in e2e
* chore: translation ordering

---------

Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
259 lines · 8.9 KiB · Python
import asyncio
import gc
import os
import signal
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from functools import partial
from typing import Any, AsyncGenerator, Callable, Iterator
from zipfile import BadZipFile

import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi.responses import ORJSONResponse, PlainTextResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
from pydantic import ValidationError
from starlette.formparsers import MultiPartParser

from immich_ml.models import get_model_deps
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_pil

from .config import PreloadModelData, log, settings
from .models.cache import ModelCache
from .schemas import (
    InferenceEntries,
    InferenceEntry,
    InferenceResponse,
    ModelFormat,
    ModelIdentity,
    ModelTask,
    ModelType,
    PipelineRequest,
    T,
)

MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger

model_cache = ModelCache(revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0
last_called: float | None = None
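# These module-level counters are updated per request by update_state() and polled by
# idle_shutdown_task(), which shuts the worker down after settings.model_ttl seconds
# of inactivity.

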
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
    global thread_pool
    log.info(
        (
            "Created in-memory cache with unloading "
            f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
        )
    )

    try:
        if settings.request_threads > 0:
            # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
            thread_pool = ThreadPoolExecutor(settings.request_threads)
            log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
        if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
            asyncio.ensure_future(idle_shutdown_task())
        if settings.preload is not None:
            await preload_models(settings.preload)
        yield
    finally:
        log.handlers.clear()
        for model in model_cache.cache._cache.values():
            del model
        if thread_pool is not None:
            thread_pool.shutdown()
        gc.collect()


async def preload_models(preload: PreloadModelData) -> None:
    log.info(f"Preloading models: clip:{preload.clip} facial_recognition:{preload.facial_recognition}")

    async def load_models(model_string: str, model_type: ModelType, model_task: ModelTask) -> None:
        for model_name in model_string.split(","):
            model_name = model_name.strip()
            model = await model_cache.get(model_name, model_type, model_task)
            await load(model)

    if preload.clip.textual is not None:
        await load_models(preload.clip.textual, ModelType.TEXTUAL, ModelTask.SEARCH)

    if preload.clip.visual is not None:
        await load_models(preload.clip.visual, ModelType.VISUAL, ModelTask.SEARCH)

    if preload.facial_recognition.detection is not None:
        await load_models(
            preload.facial_recognition.detection,
            ModelType.DETECTION,
            ModelTask.FACIAL_RECOGNITION,
        )

    if preload.facial_recognition.recognition is not None:
        await load_models(
            preload.facial_recognition.recognition,
            ModelType.RECOGNITION,
            ModelTask.FACIAL_RECOGNITION,
        )

    if preload.clip_fallback is not None:
        log.warning(
            "Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__CLIP'. "
            "Use 'MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL' and "
            "'MACHINE_LEARNING_PRELOAD__CLIP__VISUAL' instead."
        )

    if preload.facial_recognition_fallback is not None:
        log.warning(
            "Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION'. "
            "Use 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION' and "
            "'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION' instead."
        )
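# Example preload configuration via the environment variables referenced above; the
# model names are illustrative placeholders, and comma-separated values preload
# multiple models per slot (see load_models() above):
#
#   MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL=ViT-B-32__openai
#   MACHINE_LEARNING_PRELOAD__CLIP__VISUAL=ViT-B-32__openai
#   MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION=buffalo_l
#   MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION=buffalo_l

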
def update_state() -> Iterator[None]:
    global active_requests, last_called
    active_requests += 1
    last_called = time.time()
    try:
        yield
    finally:
        active_requests -= 1


def get_entries(entries: str = Form()) -> InferenceEntries:
    try:
        request: PipelineRequest = orjson.loads(entries)
        without_deps: list[InferenceEntry] = []
        with_deps: list[InferenceEntry] = []
        for task, types in request.items():
            for type, entry in types.items():
                parsed: InferenceEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": type,
                    "options": entry.get("options", {}),
                }
                dep = get_model_deps(parsed["name"], type, task)
                (with_deps if dep else without_deps).append(parsed)
        return without_deps, with_deps
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.")
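# Illustrative shape of the "entries" form field consumed by get_entries(); the task
# and type keys map to ModelTask/ModelType values, and the model names and option
# values below are hypothetical examples, not a fixed contract:
#
#   {
#     "search": {
#       "visual": {"modelName": "ViT-B-32__openai", "options": {}}
#     },
#     "facial-recognition": {
#       "detection": {"modelName": "buffalo_l", "options": {"minScore": 0.7}}
#     }
#   }

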
app = FastAPI(lifespan=lifespan)


@app.get("/")
|
|
async def root() -> ORJSONResponse:
|
|
return ORJSONResponse({"message": "Immich ML"})
|
|
|
|
|
|
@app.get("/ping")
|
|
def ping() -> PlainTextResponse:
|
|
return PlainTextResponse("pong")
|
|
|
|
|
|
@app.post("/predict", dependencies=[Depends(update_state)])
|
|
async def predict(
|
|
entries: InferenceEntries = Depends(get_entries),
|
|
image: bytes | None = File(default=None),
|
|
text: str | None = Form(default=None),
|
|
) -> Any:
|
|
if image is not None:
|
|
inputs: Image | str = await run(lambda: decode_pil(image))
|
|
elif text is not None:
|
|
inputs = text
|
|
else:
|
|
raise HTTPException(400, "Either image or text must be provided")
|
|
response = await run_inference(inputs, entries)
|
|
return ORJSONResponse(response)
|
|
|
|
|
|
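# Hypothetical client call for the endpoint above (field names come from predict();
# the host, port, model name, and file path are illustrative only):
#
#   curl -X POST http://localhost:3003/predict \
#     -F 'entries={"search": {"visual": {"modelName": "ViT-B-32__openai"}}}' \
#     -F 'image=@photo.jpg'

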
async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
    outputs: dict[ModelIdentity, Any] = {}
    response: InferenceResponse = {}

    async def _run_inference(entry: InferenceEntry) -> None:
        model = await model_cache.get(
            entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl, **entry["options"]
        )
        inputs = [payload]
        for dep in model.depends:
            try:
                inputs.append(outputs[dep])
            except KeyError:
                message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
                raise HTTPException(400, message)
        model = await load(model)
        output = await run(model.predict, *inputs, **entry["options"])
        outputs[model.identity] = output
        response[entry["task"]] = output

    without_deps, with_deps = entries
    # Run models without dependencies first so their outputs are available to the
    # dependent models in the second pass.
    await asyncio.gather(*[_run_inference(entry) for entry in without_deps])
    if with_deps:
        await asyncio.gather(*[_run_inference(entry) for entry in with_deps])
    if isinstance(payload, Image):
        response["imageHeight"], response["imageWidth"] = payload.height, payload.width

    return response


async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    if thread_pool is None:
        return func(*args, **kwargs)
    partial_func = partial(func, *args, **kwargs)
    return await asyncio.get_running_loop().run_in_executor(thread_pool, partial_func)


async def load(model: InferenceModel) -> InferenceModel:
    if model.loaded:
        return model

    def _load(model: InferenceModel) -> InferenceModel:
        if model.load_attempts > 1:
            raise HTTPException(500, f"Failed to load model '{model.model_name}'")
        with lock:
            try:
                model.load()
            except FileNotFoundError as e:
                if model.model_format == ModelFormat.ONNX:
                    raise e
                log.warning(
                    f"{model.model_format.upper()} is available, but model '{model.model_name}' does not support it.",
                    exc_info=e,
                )
                # Fall back to the ONNX format when the preferred format is missing.
                model.model_format = ModelFormat.ONNX
                model.load()
        return model

    try:
        return await run(_load, model)
    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
        # The cached model file is likely corrupt or incomplete: clear it and retry once.
        log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.")
        model.clear_cache()
        return await run(_load, model)


async def idle_shutdown_task() -> None:
    while True:
        if (
            last_called is not None
            and not active_requests
            and not lock.locked()
            and time.time() - last_called > settings.model_ttl
        ):
            log.info("Shutting down due to inactivity.")
            os.kill(os.getpid(), signal.SIGINT)
            break
        await asyncio.sleep(settings.model_ttl_poll_s)
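# The "app" object above is served by an ASGI server. Assuming this file is
# immich_ml/main.py, a typical local invocation might look like the following
# (server choice and port are illustrative):
#
#   uvicorn immich_ml.main:app --host 0.0.0.0 --port 3003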