from typing import Any

import cv2
import numpy as np
from numpy.typing import NDArray
from PIL import Image
from rapidocr.ch_ppocr_det.utils import DBPostProcess
from rapidocr.inference_engine.base import FileInfo, InferSession
from rapidocr.utils import DownloadFile, DownloadFileInput
from rapidocr.utils.typings import EngineType, LangDet, OCRVersion, TaskType
from rapidocr.utils.typings import ModelType as RapidModelType

from immich_ml.config import log
from immich_ml.models.base import InferenceModel
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession

from .schemas import TextDetectionOutput


class TextDetector(InferenceModel):
    depends = []
    identity = (ModelType.DETECTION, ModelTask.OCR)

    def __init__(self, model_name: str, **model_kwargs: Any) -> None:
        super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
        self.max_resolution = 736
        # Normalization constants fused with the 1/255 scaling step: the model expects
        # (pixel / 255 - mean) / std with mean = std = 0.5, i.e. (pixel - 127.5) / 127.5,
        # so both the mean and the inverse std are pre-scaled by 255.
        self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32) * 255.0
        self.std_inv = np.float32(1.0) / (np.array([0.5, 0.5, 0.5], dtype=np.float32) * 255.0)
        self._empty: TextDetectionOutput = {
            "boxes": np.empty(0, dtype=np.float32),
            "scores": np.empty(0, dtype=np.float32),
        }
        self.postprocess = DBPostProcess(
            thresh=0.3,
            box_thresh=model_kwargs.get("minScore", 0.5),
            max_candidates=1000,
            unclip_ratio=1.6,
            use_dilation=True,
            score_mode="fast",
        )

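    # A rough gloss of the DBPostProcess configuration above (hedged; RapidOCR's
    # DBPostProcess is the authoritative reference for the exact behavior):
    #   thresh=0.3         binarization threshold applied to the probability map
    #   box_thresh         minimum per-box confidence; lower-scoring boxes are dropped
    #   max_candidates     upper bound on candidate contours considered per image
    #   unclip_ratio=1.6   how far each detected polygon is expanded before cropping
    #   use_dilation       dilates the binary mask to better connect thin strokes
    #   score_mode="fast"  scores boxes from the axis-aligned box mean rather than the
    #                      exact polygon mean used by the slower mode
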
    def _download(self) -> None:
        model_info = InferSession.get_model_url(
            FileInfo(
                engine_type=EngineType.ONNXRUNTIME,
                ocr_version=OCRVersion.PPOCRV5,
                task_type=TaskType.DET,
                lang_type=LangDet.CH,
                model_type=RapidModelType.MOBILE if "mobile" in self.model_name else RapidModelType.SERVER,
            )
        )
        download_params = DownloadFileInput(
            file_url=model_info["model_dir"],
            sha256=model_info["SHA256"],
            save_path=self.model_path,
            logger=log,
        )
        DownloadFile.run(download_params)

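    # Note on _download above: the PP-OCRv5 detection weights are resolved through
    # RapidOCR's model registry (mobile vs. server variant chosen from the model name),
    # then fetched and verified against the published SHA-256 before being written to
    # self.model_path.
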
    def _load(self) -> ModelSession:
        # TODO: support other runtime sessions
        return OrtSession(self.model_path)

    # partly adapted from RapidOCR
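    # Detection flow (summary of _predict below): images narrower or shorter than 32 px
    # are rejected outright; otherwise the preprocessed tensor is run through the ONNX
    # session and DBPostProcess turns the probability map into one 4-point polygon per
    # text region plus a per-box score, which are then sorted into reading order.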
    def _predict(self, inputs: Image.Image) -> TextDetectionOutput:
        w, h = inputs.size
        if w < 32 or h < 32:
            return self._empty

        out = self.session.run(None, {"x": self._transform(inputs)})[0]
        boxes, scores = self.postprocess(out, (h, w))
        if len(boxes) == 0:
            return self._empty
        return {
            "boxes": self.sorted_boxes(boxes),
            "scores": np.array(scores, dtype=np.float32),
        }

    # adapted from RapidOCR
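    # _transform below resizes so the shorter side lands near max_resolution and both
    # sides are rounded to multiples of 32 (the detector's feature maps need dimensions
    # divisible by 32), then converts RGB -> BGR and normalizes in place. The fused
    # subtract/multiply pair is roughly equivalent to the conventional two-step form:
    #   img = np.array(resized_img, dtype=np.float32) / 255.0
    #   img = (img - 0.5) / 0.5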
    def _transform(self, img: Image.Image) -> NDArray[np.float32]:
        if img.height < img.width:
            ratio = float(self.max_resolution) / img.height
        else:
            ratio = float(self.max_resolution) / img.width

        resize_h = int(img.height * ratio)
        resize_w = int(img.width * ratio)

        resize_h = int(round(resize_h / 32) * 32)
        resize_w = int(round(resize_w / 32) * 32)
        resized_img = img.resize((resize_w, resize_h), resample=Image.Resampling.LANCZOS)

        img_np: NDArray[np.float32] = cv2.cvtColor(np.array(resized_img, dtype=np.float32), cv2.COLOR_RGB2BGR)  # type: ignore
        img_np -= self.mean
        img_np *= self.std_inv
        img_np = np.transpose(img_np, (2, 0, 1))
        return np.expand_dims(img_np, axis=0)

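    # sorted_boxes below orders detections into reading order: boxes are first sorted by
    # the y coordinate of their top-left point, consecutive boxes whose y values differ
    # by less than 10 px are grouped onto one text line, and boxes within a line are then
    # ordered left to right. The 1e6 multiplier keeps the line id dominant over the x
    # coordinate in the composite sort key.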
    def sorted_boxes(self, dt_boxes: NDArray[np.float32]) -> NDArray[np.float32]:
        if len(dt_boxes) == 0:
            return dt_boxes

        # Sort by y, then identify lines, then sort by (line, x)
        y_order = np.argsort(dt_boxes[:, 0, 1], kind="stable")
        sorted_y = dt_boxes[y_order, 0, 1]

        # line_ids[i] is the line index of the i-th box in y-sorted order: a new line
        # starts whenever consecutive sorted y values differ by 10 px or more.
        line_ids = np.empty(len(dt_boxes), dtype=np.int32)
        line_ids[0] = 0
        np.cumsum(np.abs(np.diff(sorted_y)) >= 10, out=line_ids[1:])

        # Create composite sort key for final ordering.
        # line_ids and the x coordinates are both already in the y-sorted frame, so the
        # line id is scaled up and x is added for tie-breaking within a line.
        sort_key = line_ids * 1e6 + dt_boxes[y_order, 0, 0]
        final_order = np.argsort(sort_key, kind="stable")
        sorted_boxes: NDArray[np.float32] = dt_boxes[y_order[final_order]]
        return sorted_boxes

    def configure(self, **kwargs: Any) -> None:
        if (max_resolution := kwargs.get("maxResolution")) is not None:
            self.max_resolution = max_resolution
        if (min_score := kwargs.get("minScore")) is not None:
            self.postprocess.box_thresh = min_score
        if (score_mode := kwargs.get("scoreMode")) is not None:
            self.postprocess.score_mode = score_mode
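
# Minimal usage sketch (hedged): this assumes the InferenceModel base class exposes a
# public entry point that lazily downloads/loads the model and dispatches to _predict,
# as the other immich_ml models are driven. The model name and call below are
# illustrative assumptions, not guarantees from this file.
#
#   from PIL import Image
#
#   detector = TextDetector("pp-ocr-v5_mobile", minScore=0.5)
#   detector.configure(maxResolution=736, scoreMode="fast")
#   output = detector.predict(Image.open("photo.jpg").convert("RGB"))
#   print(output["boxes"].shape, output["scores"].shape)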