From 35a521c6ec95eff2a4b5d8808bfb38fec6203a4a Mon Sep 17 00:00:00 2001 From: Mert <101130780+mertalev@users.noreply.github.com> Date: Thu, 5 Mar 2026 12:01:47 -0500 Subject: [PATCH] fix(ml): batch size setting (#26524) --- docs/docs/install/environment-variables.md | 2 + machine-learning/immich_ml/config.py | 5 +- .../models/facial_recognition/recognition.py | 2 +- .../immich_ml/models/ocr/detection.py | 4 +- .../immich_ml/models/ocr/recognition.py | 7 +- machine-learning/test_main.py | 77 ++++++++++++++++++- 6 files changed, 89 insertions(+), 8 deletions(-) diff --git a/docs/docs/install/environment-variables.md b/docs/docs/install/environment-variables.md index 07b37f0e41..e9e3bb032c 100644 --- a/docs/docs/install/environment-variables.md +++ b/docs/docs/install/environment-variables.md @@ -166,6 +166,8 @@ Redis (Sentinel) URL example JSON before encoding: | `MACHINE_LEARNING_PRELOAD__CLIP__VISUAL` | Comma-separated list of (visual) CLIP model(s) to preload and cache | | machine learning | | `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION` | Comma-separated list of (recognition) facial recognition model(s) to preload and cache | | machine learning | | `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION` | Comma-separated list of (detection) facial recognition model(s) to preload and cache | | machine learning | +| `MACHINE_LEARNING_PRELOAD__OCR__RECOGNITION` | Comma-separated list of (recognition) OCR model(s) to preload and cache | | machine learning | +| `MACHINE_LEARNING_PRELOAD__OCR__DETECTION` | Comma-separated list of (detection) OCR model(s) to preload and cache | | machine learning | | `MACHINE_LEARNING_ANN` | Enable ARM-NN hardware acceleration if supported | `True` | machine learning | | `MACHINE_LEARNING_ANN_FP16_TURBO` | Execute operations in FP16 precision: increasing speed, reducing precision (applies only to ARM-NN) | `False` | machine learning | | `MACHINE_LEARNING_ANN_TUNING_LEVEL` | ARM-NN GPU tuning level (1: rapid, 2: normal, 3: exhaustive) | `2` | machine learning | diff --git a/machine-learning/immich_ml/config.py b/machine-learning/immich_ml/config.py index 08dca04a4d..8b383f5419 100644 --- a/machine-learning/immich_ml/config.py +++ b/machine-learning/immich_ml/config.py @@ -48,8 +48,11 @@ class PreloadModelData(BaseModel): class MaxBatchSize(BaseModel): + ocr_fallback: str | None = os.getenv("MACHINE_LEARNING_MAX_BATCH_SIZE__TEXT_RECOGNITION", None) + if ocr_fallback is not None: + os.environ["MACHINE_LEARNING_MAX_BATCH_SIZE__OCR"] = ocr_fallback facial_recognition: int | None = None - text_recognition: int | None = None + ocr: int | None = None class Settings(BaseSettings): diff --git a/machine-learning/immich_ml/models/facial_recognition/recognition.py b/machine-learning/immich_ml/models/facial_recognition/recognition.py index 759992a600..ed1897c9f9 100644 --- a/machine-learning/immich_ml/models/facial_recognition/recognition.py +++ b/machine-learning/immich_ml/models/facial_recognition/recognition.py @@ -29,7 +29,7 @@ class FaceRecognizer(InferenceModel): def __init__(self, model_name: str, **model_kwargs: Any) -> None: super().__init__(model_name, **model_kwargs) - max_batch_size = settings.max_batch_size.facial_recognition if settings.max_batch_size else None + max_batch_size = settings.max_batch_size and settings.max_batch_size.facial_recognition self.batch_size = max_batch_size if max_batch_size else self._batch_size_default def _load(self) -> ModelSession: diff --git a/machine-learning/immich_ml/models/ocr/detection.py b/machine-learning/immich_ml/models/ocr/detection.py index d34a51684e..0a2cb8ad91 100644 --- a/machine-learning/immich_ml/models/ocr/detection.py +++ b/machine-learning/immich_ml/models/ocr/detection.py @@ -22,7 +22,7 @@ class TextDetector(InferenceModel): depends = [] identity = (ModelType.DETECTION, ModelTask.OCR) - def __init__(self, model_name: str, **model_kwargs: Any) -> None: + def __init__(self, model_name: str, min_score: float = 0.5, **model_kwargs: Any) -> None: super().__init__(model_name.split("__")[-1], **model_kwargs, model_format=ModelFormat.ONNX) self.max_resolution = 736 self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32) @@ -33,7 +33,7 @@ class TextDetector(InferenceModel): } self.postprocess = DBPostProcess( thresh=0.3, - box_thresh=model_kwargs.get("minScore", 0.5), + box_thresh=model_kwargs.get("minScore", min_score), max_candidates=1000, unclip_ratio=1.6, use_dilation=True, diff --git a/machine-learning/immich_ml/models/ocr/recognition.py b/machine-learning/immich_ml/models/ocr/recognition.py index e968392881..6408e4818f 100644 --- a/machine-learning/immich_ml/models/ocr/recognition.py +++ b/machine-learning/immich_ml/models/ocr/recognition.py @@ -24,9 +24,9 @@ class TextRecognizer(InferenceModel): depends = [(ModelType.DETECTION, ModelTask.OCR)] identity = (ModelType.RECOGNITION, ModelTask.OCR) - def __init__(self, model_name: str, **model_kwargs: Any) -> None: + def __init__(self, model_name: str, min_score: float = 0.9, **model_kwargs: Any) -> None: self.language = LangRec[model_name.split("__")[0]] if "__" in model_name else LangRec.CH - self.min_score = model_kwargs.get("minScore", 0.9) + self.min_score = model_kwargs.get("minScore", min_score) self._empty: TextRecognitionOutput = { "box": np.empty(0, dtype=np.float32), "boxScore": np.empty(0, dtype=np.float32), @@ -57,10 +57,11 @@ class TextRecognizer(InferenceModel): def _load(self) -> ModelSession: # TODO: support other runtimes session = OrtSession(self.model_path) + max_batch_size = settings.max_batch_size and settings.max_batch_size.ocr self.model = RapidTextRecognizer( OcrOptions( session=session.session, - rec_batch_num=settings.max_batch_size.text_recognition if settings.max_batch_size is not None else 6, + rec_batch_num=max_batch_size if max_batch_size else 6, rec_img_shape=(3, 48, 320), lang_type=self.language, ) diff --git a/machine-learning/test_main.py b/machine-learning/test_main.py index f37880610a..a5cf1acc2e 100644 --- a/machine-learning/test_main.py +++ b/machine-learning/test_main.py @@ -18,7 +18,7 @@ from PIL import Image from pytest import MonkeyPatch from pytest_mock import MockerFixture -from immich_ml.config import Settings, settings +from immich_ml.config import MaxBatchSize, Settings, settings from immich_ml.main import load, preload_models from immich_ml.models.base import InferenceModel from immich_ml.models.cache import ModelCache @@ -26,6 +26,9 @@ from immich_ml.models.clip.textual import MClipTextualEncoder, OpenClipTextualEn from immich_ml.models.clip.visual import OpenClipVisualEncoder from immich_ml.models.facial_recognition.detection import FaceDetector from immich_ml.models.facial_recognition.recognition import FaceRecognizer +from immich_ml.models.ocr.detection import TextDetector +from immich_ml.models.ocr.recognition import TextRecognizer +from immich_ml.models.ocr.schemas import OcrOptions from immich_ml.schemas import ModelFormat, ModelPrecision, ModelTask, ModelType from immich_ml.sessions.ann import AnnSession from immich_ml.sessions.ort import OrtSession @@ -855,6 +858,78 @@ class TestFaceRecognition: onnx.load.assert_not_called() onnx.save.assert_not_called() + def test_set_custom_max_batch_size(self, mocker: MockerFixture) -> None: + mocker.patch.object(settings, "max_batch_size", MaxBatchSize(facial_recognition=2)) + + recognizer = FaceRecognizer("buffalo_l", cache_dir="test_cache") + + assert recognizer.batch_size == 2 + + def test_ignore_other_custom_max_batch_size(self, mocker: MockerFixture) -> None: + mocker.patch.object(settings, "max_batch_size", MaxBatchSize(ocr=2)) + + recognizer = FaceRecognizer("buffalo_l", cache_dir="test_cache") + + assert recognizer.batch_size is None + + +class TestOcr: + def test_set_det_min_score(self, path: mock.Mock) -> None: + path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx" + + text_detector = TextDetector("PP-OCRv5_mobile", min_score=0.8, cache_dir="test_cache") + + assert text_detector.postprocess.box_thresh == 0.8 + + def test_set_rec_min_score(self, path: mock.Mock) -> None: + path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx" + + text_recognizer = TextRecognizer("PP-OCRv5_mobile", min_score=0.8, cache_dir="test_cache") + + assert text_recognizer.min_score == 0.8 + + def test_set_rec_set_default_max_batch_size( + self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture + ) -> None: + path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx" + mocker.patch("immich_ml.models.base.InferenceModel.download") + rapid_recognizer = mocker.patch("immich_ml.models.ocr.recognition.RapidTextRecognizer") + + text_recognizer = TextRecognizer("PP-OCRv5_mobile", cache_dir="test_cache") + text_recognizer.load() + + rapid_recognizer.assert_called_once_with( + OcrOptions(session=ort_session.return_value, rec_batch_num=6, rec_img_shape=(3, 48, 320)) + ) + + def test_set_custom_max_batch_size(self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture) -> None: + path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx" + mocker.patch("immich_ml.models.base.InferenceModel.download") + rapid_recognizer = mocker.patch("immich_ml.models.ocr.recognition.RapidTextRecognizer") + mocker.patch.object(settings, "max_batch_size", MaxBatchSize(ocr=4)) + + text_recognizer = TextRecognizer("PP-OCRv5_mobile", cache_dir="test_cache") + text_recognizer.load() + + rapid_recognizer.assert_called_once_with( + OcrOptions(session=ort_session.return_value, rec_batch_num=4, rec_img_shape=(3, 48, 320)) + ) + + def test_ignore_other_custom_max_batch_size( + self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture + ) -> None: + path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx" + mocker.patch("immich_ml.models.base.InferenceModel.download") + rapid_recognizer = mocker.patch("immich_ml.models.ocr.recognition.RapidTextRecognizer") + mocker.patch.object(settings, "max_batch_size", MaxBatchSize(facial_recognition=3)) + + text_recognizer = TextRecognizer("PP-OCRv5_mobile", cache_dir="test_cache") + text_recognizer.load() + + rapid_recognizer.assert_called_once_with( + OcrOptions(session=ort_session.return_value, rec_batch_num=6, rec_img_shape=(3, 48, 320)) + ) + @pytest.mark.asyncio class TestCache: