mirror of
https://github.com/immich-app/immich.git
synced 2026-03-09 11:23:46 -04:00
fix(ml): batch size setting (#26524)
This commit is contained in:
parent
09fabb36b6
commit
35a521c6ec
@ -166,6 +166,8 @@ Redis (Sentinel) URL example JSON before encoding:
|
||||
| `MACHINE_LEARNING_PRELOAD__CLIP__VISUAL` | Comma-separated list of (visual) CLIP model(s) to preload and cache | | machine learning |
|
||||
| `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION` | Comma-separated list of (recognition) facial recognition model(s) to preload and cache | | machine learning |
|
||||
| `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION` | Comma-separated list of (detection) facial recognition model(s) to preload and cache | | machine learning |
|
||||
| `MACHINE_LEARNING_PRELOAD__OCR__RECOGNITION` | Comma-separated list of (recognition) OCR model(s) to preload and cache | | machine learning |
|
||||
| `MACHINE_LEARNING_PRELOAD__OCR__DETECTION` | Comma-separated list of (detection) OCR model(s) to preload and cache | | machine learning |
|
||||
| `MACHINE_LEARNING_ANN` | Enable ARM-NN hardware acceleration if supported | `True` | machine learning |
|
||||
| `MACHINE_LEARNING_ANN_FP16_TURBO` | Execute operations in FP16 precision: increasing speed, reducing precision (applies only to ARM-NN) | `False` | machine learning |
|
||||
| `MACHINE_LEARNING_ANN_TUNING_LEVEL` | ARM-NN GPU tuning level (1: rapid, 2: normal, 3: exhaustive) | `2` | machine learning |
|
||||
|
||||
@ -48,8 +48,11 @@ class PreloadModelData(BaseModel):
|
||||
|
||||
|
||||
class MaxBatchSize(BaseModel):
|
||||
ocr_fallback: str | None = os.getenv("MACHINE_LEARNING_MAX_BATCH_SIZE__TEXT_RECOGNITION", None)
|
||||
if ocr_fallback is not None:
|
||||
os.environ["MACHINE_LEARNING_MAX_BATCH_SIZE__OCR"] = ocr_fallback
|
||||
facial_recognition: int | None = None
|
||||
text_recognition: int | None = None
|
||||
ocr: int | None = None
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
|
||||
@ -29,7 +29,7 @@ class FaceRecognizer(InferenceModel):
|
||||
|
||||
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
|
||||
super().__init__(model_name, **model_kwargs)
|
||||
max_batch_size = settings.max_batch_size.facial_recognition if settings.max_batch_size else None
|
||||
max_batch_size = settings.max_batch_size and settings.max_batch_size.facial_recognition
|
||||
self.batch_size = max_batch_size if max_batch_size else self._batch_size_default
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
|
||||
@ -22,7 +22,7 @@ class TextDetector(InferenceModel):
|
||||
depends = []
|
||||
identity = (ModelType.DETECTION, ModelTask.OCR)
|
||||
|
||||
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
|
||||
def __init__(self, model_name: str, min_score: float = 0.5, **model_kwargs: Any) -> None:
|
||||
super().__init__(model_name.split("__")[-1], **model_kwargs, model_format=ModelFormat.ONNX)
|
||||
self.max_resolution = 736
|
||||
self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
|
||||
@ -33,7 +33,7 @@ class TextDetector(InferenceModel):
|
||||
}
|
||||
self.postprocess = DBPostProcess(
|
||||
thresh=0.3,
|
||||
box_thresh=model_kwargs.get("minScore", 0.5),
|
||||
box_thresh=model_kwargs.get("minScore", min_score),
|
||||
max_candidates=1000,
|
||||
unclip_ratio=1.6,
|
||||
use_dilation=True,
|
||||
|
||||
@ -24,9 +24,9 @@ class TextRecognizer(InferenceModel):
|
||||
depends = [(ModelType.DETECTION, ModelTask.OCR)]
|
||||
identity = (ModelType.RECOGNITION, ModelTask.OCR)
|
||||
|
||||
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
|
||||
def __init__(self, model_name: str, min_score: float = 0.9, **model_kwargs: Any) -> None:
|
||||
self.language = LangRec[model_name.split("__")[0]] if "__" in model_name else LangRec.CH
|
||||
self.min_score = model_kwargs.get("minScore", 0.9)
|
||||
self.min_score = model_kwargs.get("minScore", min_score)
|
||||
self._empty: TextRecognitionOutput = {
|
||||
"box": np.empty(0, dtype=np.float32),
|
||||
"boxScore": np.empty(0, dtype=np.float32),
|
||||
@ -57,10 +57,11 @@ class TextRecognizer(InferenceModel):
|
||||
def _load(self) -> ModelSession:
|
||||
# TODO: support other runtimes
|
||||
session = OrtSession(self.model_path)
|
||||
max_batch_size = settings.max_batch_size and settings.max_batch_size.ocr
|
||||
self.model = RapidTextRecognizer(
|
||||
OcrOptions(
|
||||
session=session.session,
|
||||
rec_batch_num=settings.max_batch_size.text_recognition if settings.max_batch_size is not None else 6,
|
||||
rec_batch_num=max_batch_size if max_batch_size else 6,
|
||||
rec_img_shape=(3, 48, 320),
|
||||
lang_type=self.language,
|
||||
)
|
||||
|
||||
@ -18,7 +18,7 @@ from PIL import Image
|
||||
from pytest import MonkeyPatch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from immich_ml.config import Settings, settings
|
||||
from immich_ml.config import MaxBatchSize, Settings, settings
|
||||
from immich_ml.main import load, preload_models
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.cache import ModelCache
|
||||
@ -26,6 +26,9 @@ from immich_ml.models.clip.textual import MClipTextualEncoder, OpenClipTextualEn
|
||||
from immich_ml.models.clip.visual import OpenClipVisualEncoder
|
||||
from immich_ml.models.facial_recognition.detection import FaceDetector
|
||||
from immich_ml.models.facial_recognition.recognition import FaceRecognizer
|
||||
from immich_ml.models.ocr.detection import TextDetector
|
||||
from immich_ml.models.ocr.recognition import TextRecognizer
|
||||
from immich_ml.models.ocr.schemas import OcrOptions
|
||||
from immich_ml.schemas import ModelFormat, ModelPrecision, ModelTask, ModelType
|
||||
from immich_ml.sessions.ann import AnnSession
|
||||
from immich_ml.sessions.ort import OrtSession
|
||||
@ -855,6 +858,78 @@ class TestFaceRecognition:
|
||||
onnx.load.assert_not_called()
|
||||
onnx.save.assert_not_called()
|
||||
|
||||
def test_set_custom_max_batch_size(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(settings, "max_batch_size", MaxBatchSize(facial_recognition=2))
|
||||
|
||||
recognizer = FaceRecognizer("buffalo_l", cache_dir="test_cache")
|
||||
|
||||
assert recognizer.batch_size == 2
|
||||
|
||||
def test_ignore_other_custom_max_batch_size(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(settings, "max_batch_size", MaxBatchSize(ocr=2))
|
||||
|
||||
recognizer = FaceRecognizer("buffalo_l", cache_dir="test_cache")
|
||||
|
||||
assert recognizer.batch_size is None
|
||||
|
||||
|
||||
class TestOcr:
|
||||
def test_set_det_min_score(self, path: mock.Mock) -> None:
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||
|
||||
text_detector = TextDetector("PP-OCRv5_mobile", min_score=0.8, cache_dir="test_cache")
|
||||
|
||||
assert text_detector.postprocess.box_thresh == 0.8
|
||||
|
||||
def test_set_rec_min_score(self, path: mock.Mock) -> None:
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||
|
||||
text_recognizer = TextRecognizer("PP-OCRv5_mobile", min_score=0.8, cache_dir="test_cache")
|
||||
|
||||
assert text_recognizer.min_score == 0.8
|
||||
|
||||
def test_set_rec_set_default_max_batch_size(
|
||||
self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
|
||||
) -> None:
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||
mocker.patch("immich_ml.models.base.InferenceModel.download")
|
||||
rapid_recognizer = mocker.patch("immich_ml.models.ocr.recognition.RapidTextRecognizer")
|
||||
|
||||
text_recognizer = TextRecognizer("PP-OCRv5_mobile", cache_dir="test_cache")
|
||||
text_recognizer.load()
|
||||
|
||||
rapid_recognizer.assert_called_once_with(
|
||||
OcrOptions(session=ort_session.return_value, rec_batch_num=6, rec_img_shape=(3, 48, 320))
|
||||
)
|
||||
|
||||
def test_set_custom_max_batch_size(self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture) -> None:
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||
mocker.patch("immich_ml.models.base.InferenceModel.download")
|
||||
rapid_recognizer = mocker.patch("immich_ml.models.ocr.recognition.RapidTextRecognizer")
|
||||
mocker.patch.object(settings, "max_batch_size", MaxBatchSize(ocr=4))
|
||||
|
||||
text_recognizer = TextRecognizer("PP-OCRv5_mobile", cache_dir="test_cache")
|
||||
text_recognizer.load()
|
||||
|
||||
rapid_recognizer.assert_called_once_with(
|
||||
OcrOptions(session=ort_session.return_value, rec_batch_num=4, rec_img_shape=(3, 48, 320))
|
||||
)
|
||||
|
||||
def test_ignore_other_custom_max_batch_size(
|
||||
self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
|
||||
) -> None:
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||
mocker.patch("immich_ml.models.base.InferenceModel.download")
|
||||
rapid_recognizer = mocker.patch("immich_ml.models.ocr.recognition.RapidTextRecognizer")
|
||||
mocker.patch.object(settings, "max_batch_size", MaxBatchSize(facial_recognition=3))
|
||||
|
||||
text_recognizer = TextRecognizer("PP-OCRv5_mobile", cache_dir="test_cache")
|
||||
text_recognizer.load()
|
||||
|
||||
rapid_recognizer.assert_called_once_with(
|
||||
OcrOptions(session=ort_session.return_value, rec_batch_num=6, rec_img_shape=(3, 48, 320))
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCache:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user