Mirror of https://github.com/immich-app/immich.git (synced 2025-05-31 20:25:32 -04:00)
fix(ml): load models in separate threads (#4034)
* load models in thread
* set clip mode logs to debug level
* updated tests
* made fixtures slightly less ugly
* moved responses to json file
* formatting
This commit is contained in:
parent f1db257628
commit 258b98c262
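The core idea of the change: model loading is blocking disk/network I/O plus CPU-heavy ONNX session construction, so doing it on the event loop stalls every in-flight request. This commit moves loads into the existing request thread pool and serializes them with one threading.Lock per model type. Below is a minimal standalone sketch of that pattern; FakeModel, the lock-table contents, and main() are invented for illustration, and only the lock-plus-executor structure mirrors the actual diff that follows.

import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor

# one lock per model type, analogous to app.state.locks in the diff
locks = {"clip": threading.Lock(), "facial-recognition": threading.Lock()}
thread_pool = ThreadPoolExecutor(max_workers=2)


class FakeModel:
    def __init__(self, model_type: str) -> None:
        self.model_type = model_type
        self.loaded = False

    def load(self) -> None:
        if self.loaded:
            return
        time.sleep(1)  # stand-in for expensive ONNX session construction
        self.loaded = True


async def load(model: FakeModel) -> FakeModel:
    if model.loaded:
        return model

    def _load() -> None:
        # the lock keeps two concurrent requests for the same model type
        # from loading it twice
        with locks[model.model_type]:
            model.load()

    # run the blocking load in a worker thread, not on the event loop
    await asyncio.get_running_loop().run_in_executor(thread_pool, _load)
    return model


async def main() -> None:
    model = FakeModel("clip")
    # two concurrent requests: the load runs once, off the event loop
    await asyncio.gather(load(model), load(model))
    assert model.loaded


asyncio.run(main())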
machine-learning/app/conftest.py

@@ -1,4 +1,5 @@
-from typing import Iterator, TypeAlias
+import json
+from typing import Any, Iterator, TypeAlias
 from unittest import mock
 
 import numpy as np
@@ -31,3 +32,8 @@ def mock_get_model() -> Iterator[mock.Mock]:
 def deployed_app() -> TestClient:
     init_state()
     return TestClient(app)
+
+
+@pytest.fixture(scope="session")
+def responses() -> dict[str, Any]:
+    return json.load(open("responses.json", "r"))
machine-learning/app/main.py

@@ -1,10 +1,13 @@
 import asyncio
+import threading
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
+from zipfile import BadZipFile
 
 import orjson
 from fastapi import FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
+from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
 from starlette.formparsers import MultiPartParser
 
 from app.models.base import InferenceModel
@@ -31,6 +34,7 @@ def init_state() -> None:
     )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
+    app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
     log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
 
 
@@ -63,14 +67,49 @@ async def predict(
         inputs = text
     else:
         raise HTTPException(400, "Either image or text must be provided")
+    try:
+        kwargs = orjson.loads(options)
+    except orjson.JSONDecodeError:
+        raise HTTPException(400, f"Invalid options JSON: {options}")
 
-    model: InferenceModel = await app.state.model_cache.get(model_name, model_type, **orjson.loads(options))
+    model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
+    model.configure(**kwargs)
     outputs = await run(model, inputs)
     return ORJSONResponse(outputs)
 
 
 async def run(model: InferenceModel, inputs: Any) -> Any:
-    if app.state.thread_pool is not None:
-        return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
-    else:
-        return model.predict(inputs)
+    if app.state.thread_pool is None:
+        return model.predict(inputs)
+
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+
+
+async def load(model: InferenceModel) -> InferenceModel:
+    if model.loaded:
+        return model
+
+    def _load() -> None:
+        with app.state.locks[model.model_type]:
+            model.load()
+
+    loop = asyncio.get_running_loop()
+    try:
+        if app.state.thread_pool is None:
+            model.load()
+        else:
+            await loop.run_in_executor(app.state.thread_pool, _load)
+        return model
+    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
+        log.warn(
+            (
+                f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
+                "Clearing cache and retrying."
+            )
+        )
+        model.clear_cache()
+        if app.state.thread_pool is None:
+            model.load()
+        else:
+            await loop.run_in_executor(app.state.thread_pool, _load)
+        return model
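With load() in place, the /predict route stays responsive even on a cold start. For reference, a hedged sketch of a client call against the endpoint, using the same form fields the updated endpoint tests later in this commit exercise; the host/port, field names, and model name come from the diff itself, while httpx and the test.jpg path are arbitrary choices for illustration.

# Sketch of a client call to the consolidated /predict endpoint.
# Field names and values mirror the endpoint tests in this commit;
# httpx is an arbitrary HTTP client choice, not something immich requires.
import json

import httpx

with open("test.jpg", "rb") as f:  # hypothetical local image
    response = httpx.post(
        "http://localhost:3003/predict",
        data={
            "modelName": "microsoft/resnet-50",
            "modelType": "image-classification",
            "options": json.dumps({"minScore": 0.0}),
        },
        files={"image": f.read()},
    )
response.raise_for_status()
print(response.json())  # a list of {"label": ..., "score": ...} tags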
machine-learning/app/models/base.py

@@ -5,10 +5,8 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
 from typing import Any
-from zipfile import BadZipFile
 
 import onnxruntime as ort
-from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
 
 from ..config import get_cache_dir, log, settings
 from ..schemas import ModelType
@@ -21,16 +19,13 @@ class InferenceModel(ABC):
         self,
         model_name: str,
         cache_dir: Path | str | None = None,
-        eager: bool = True,
         inter_op_num_threads: int = settings.model_inter_op_threads,
         intra_op_num_threads: int = settings.model_intra_op_threads,
         **model_kwargs: Any,
     ) -> None:
         self.model_name = model_name
-        self._loaded = False
+        self.loaded = False
         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
-        loader = self.load if eager else self.download
-
         self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
         # don't pre-allocate more memory than needed
         self.provider_options = model_kwargs.pop(
@@ -55,34 +50,23 @@ class InferenceModel(ABC):
         self.sess_options.intra_op_num_threads = intra_op_num_threads
         self.sess_options.enable_cpu_mem_arena = False
 
-        try:
-            loader(**model_kwargs)
-        except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
-            log.warn(
-                (
-                    f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
-                    "Clearing cache and retrying."
-                )
-            )
-            self.clear_cache()
-            loader(**model_kwargs)
-
-    def download(self, **model_kwargs: Any) -> None:
+    def download(self) -> None:
         if not self.cached:
             log.info(
-                (f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
+                (f"Downloading {self.model_type.replace('-', ' ')} model '{self.model_name}'." "This may take a while.")
             )
-            self._download(**model_kwargs)
+            self._download()
 
-    def load(self, **model_kwargs: Any) -> None:
-        self.download(**model_kwargs)
-        self._load(**model_kwargs)
-        self._loaded = True
+    def load(self) -> None:
+        if self.loaded:
+            return
+        self.download()
+        log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}'")
+        self._load()
+        self.loaded = True
 
     def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
-        if not self._loaded:
-            log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
-            self.load()
+        self.load()
         if model_kwargs:
             self.configure(**model_kwargs)
         return self._predict(inputs)
@@ -95,11 +79,11 @@ class InferenceModel(ABC):
         pass
 
     @abstractmethod
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         ...
 
     @abstractmethod
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         ...
 
     @property
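The retry-on-corrupt-cache logic that used to run in __init__ (via the eager/loader path) now lives in main.load(), and InferenceModel.load() itself became idempotent. A small sketch of the resulting contract, assuming the ImageClassifier subclass and model name that appear in this repo's tests (the module path is an assumption; running this for real would download the model):

# Sketch of the new lazy-load contract. Module path is assumed;
# ImageClassifier and "microsoft/resnet-50" come from this repo's tests.
from app.models.image_classification import ImageClassifier

model = ImageClassifier("microsoft/resnet-50")
assert not model.loaded  # __init__ no longer downloads or loads anything

model.load()   # downloads if not cached, builds the ONNX session, logs once
assert model.loaded

model.load()   # no-op: load() returns early because self.loaded is True
# model.predict(...) also calls load(), so fully lazy use works as well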
machine-learning/app/models/cache.py

@@ -17,7 +17,7 @@ class ModelCache:
         revalidate: bool = False,
         timeout: int | None = None,
         profiling: bool = False,
-    ):
+    ) -> None:
         """
         Args:
             ttl: Unloads model after this duration. Disabled if None. Defaults to None.
machine-learning/app/models/clip.py

@@ -42,7 +42,7 @@ class CLIPEncoder(InferenceModel):
         jina_model_name = self._get_jina_model_name(model_name)
         super().__init__(jina_model_name, cache_dir, **model_kwargs)
 
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
         text_onnx_path = self.cache_dir / "textual.onnx"
         vision_onnx_path = self.cache_dir / "visual.onnx"
@@ -53,8 +53,9 @@ class CLIPEncoder(InferenceModel):
         if not vision_onnx_path.is_file():
             self._download_model(*models[1])
 
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         if self.mode == "text" or self.mode is None:
+            log.debug(f"Loading clip text model '{self.model_name}'")
             self.text_model = ort.InferenceSession(
                 self.cache_dir / "textual.onnx",
                 sess_options=self.sess_options,
@@ -65,6 +66,7 @@ class CLIPEncoder(InferenceModel):
             self.tokenizer = Tokenizer(self.model_name)
 
         if self.mode == "vision" or self.mode is None:
+            log.debug(f"Loading clip vision model '{self.model_name}'")
             self.vision_model = ort.InferenceSession(
                 self.cache_dir / "visual.onnx",
                 sess_options=self.sess_options,
machine-learning/app/models/facial_recognition.py

@@ -26,7 +26,7 @@ class FaceRecognizer(InferenceModel):
         self.min_score = model_kwargs.pop("minScore", min_score)
         super().__init__(model_name, cache_dir, **model_kwargs)
 
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         zip_file = self.cache_dir / f"{self.model_name}.zip"
         download_file(f"{BASE_REPO_URL}/{self.model_name}.zip", zip_file)
         with zipfile.ZipFile(zip_file, "r") as zip:
@@ -36,7 +36,7 @@ class FaceRecognizer(InferenceModel):
             zip.extractall(self.cache_dir, members=[det_file, rec_file])
         zip_file.unlink()
 
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         try:
             det_file = next(self.cache_dir.glob("det_*.onnx"))
             rec_file = next(self.cache_dir.glob("w600k_*.onnx"))
machine-learning/app/models/image_classification.py

@@ -26,7 +26,7 @@ class ImageClassifier(InferenceModel):
         self.min_score = model_kwargs.pop("minScore", min_score)
         super().__init__(model_name, cache_dir, **model_kwargs)
 
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         snapshot_download(
             cache_dir=self.cache_dir,
             repo_id=self.model_name,
@@ -35,10 +35,10 @@ class ImageClassifier(InferenceModel):
             local_dir_use_symlinks=True,
         )
 
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
         model_path = self.cache_dir / "model.onnx"
-        model_kwargs |= {
+        model_kwargs = {
             "cache_dir": self.cache_dir,
             "provider": self.providers[0],
             "provider_options": self.provider_options[0],
machine-learning/app/test_main.py

@@ -1,11 +1,11 @@
+import json
 import pickle
 from io import BytesIO
-from typing import TypeAlias
+from typing import Any, TypeAlias
 from unittest import mock
 
 import cv2
 import numpy as np
-import onnxruntime as ort
 import pytest
 from fastapi.testclient import TestClient
 from PIL import Image
@@ -31,23 +31,6 @@ class TestImageClassifier:
         {"label": "probably a virus", "score": 0.01},
     ]
 
-    def test_eager_init(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(ImageClassifier, "download")
-        mock_load = mocker.patch.object(ImageClassifier, "load")
-        classifier = ImageClassifier("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")
-
-        assert classifier.model_name == "test_model_name"
-        mock_load.assert_called_once_with(test_arg="test_arg")
-
-    def test_lazy_init(self, mocker: MockerFixture) -> None:
-        mock_download = mocker.patch.object(ImageClassifier, "download")
-        mock_load = mocker.patch.object(ImageClassifier, "load")
-        face_model = ImageClassifier("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")
-
-        assert face_model.model_name == "test_model_name"
-        mock_download.assert_called_once_with(test_arg="test_arg")
-        mock_load.assert_not_called()
-
     def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(ImageClassifier, "load")
         classifier = ImageClassifier("test_model_name", min_score=0.0)
@@ -74,23 +57,6 @@ class TestImageClassifier:
 class TestCLIP:
     embedding = np.random.rand(512).astype(np.float32)
 
-    def test_eager_init(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(CLIPEncoder, "download")
-        mock_load = mocker.patch.object(CLIPEncoder, "load")
-        clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=True, test_arg="test_arg")
-
-        assert clip_model.model_name == "ViT-B-32::openai"
-        mock_load.assert_called_once_with(test_arg="test_arg")
-
-    def test_lazy_init(self, mocker: MockerFixture) -> None:
-        mock_download = mocker.patch.object(CLIPEncoder, "download")
-        mock_load = mocker.patch.object(CLIPEncoder, "load")
-        clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=False, test_arg="test_arg")
-
-        assert clip_model.model_name == "ViT-B-32::openai"
-        mock_download.assert_called_once_with(test_arg="test_arg")
-        mock_load.assert_not_called()
-
     def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
@@ -119,23 +85,6 @@ class TestCLIP:
 
 
 class TestFaceRecognition:
-    def test_eager_init(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(FaceRecognizer, "download")
-        mock_load = mocker.patch.object(FaceRecognizer, "load")
-        face_model = FaceRecognizer("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")
-
-        assert face_model.model_name == "test_model_name"
-        mock_load.assert_called_once_with(test_arg="test_arg")
-
-    def test_lazy_init(self, mocker: MockerFixture) -> None:
-        mock_download = mocker.patch.object(FaceRecognizer, "download")
-        mock_load = mocker.patch.object(FaceRecognizer, "load")
-        face_model = FaceRecognizer("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")
-
-        assert face_model.model_name == "test_model_name"
-        mock_download.assert_called_once_with(test_arg="test_arg")
-        mock_load.assert_not_called()
-
     def test_set_min_score(self, mocker: MockerFixture) -> None:
         mocker.patch.object(FaceRecognizer, "load")
         face_recognizer = FaceRecognizer("test_model_name", cache_dir="test_cache", min_score=0.5)
@@ -220,45 +169,64 @@ class TestCache:
     reason="More time-consuming since it deploys the app and loads models.",
 )
 class TestEndpoints:
-    def test_tagging_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
+    def test_tagging_endpoint(
+        self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient
+    ) -> None:
         byte_image = BytesIO()
         pil_image.save(byte_image, format="jpeg")
-        headers = {"Content-Type": "image/jpg"}
         response = deployed_app.post(
-            "http://localhost:3003/image-classifier/tag-image",
-            content=byte_image.getvalue(),
-            headers=headers,
+            "http://localhost:3003/predict",
+            data={
+                "modelName": "microsoft/resnet-50",
+                "modelType": "image-classification",
+                "options": json.dumps({"minScore": 0.0}),
+            },
+            files={"image": byte_image.getvalue()},
         )
         assert response.status_code == 200
+        assert response.json() == responses["image-classification"]
 
-    def test_clip_image_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
+    def test_clip_image_endpoint(
+        self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient
+    ) -> None:
         byte_image = BytesIO()
         pil_image.save(byte_image, format="jpeg")
-        headers = {"Content-Type": "image/jpg"}
         response = deployed_app.post(
-            "http://localhost:3003/sentence-transformer/encode-image",
-            content=byte_image.getvalue(),
-            headers=headers,
+            "http://localhost:3003/predict",
+            data={"modelName": "ViT-B-32::openai", "modelType": "clip", "options": json.dumps({"mode": "vision"})},
+            files={"image": byte_image.getvalue()},
         )
         assert response.status_code == 200
+        assert response.json() == responses["clip"]["image"]
 
-    def test_clip_text_endpoint(self, deployed_app: TestClient) -> None:
+    def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
         response = deployed_app.post(
-            "http://localhost:3003/sentence-transformer/encode-text",
-            json={"text": "test search query"},
+            "http://localhost:3003/predict",
+            data={
+                "modelName": "ViT-B-32::openai",
+                "modelType": "clip",
+                "text": "test search query",
+                "options": json.dumps({"mode": "text"}),
+            },
         )
         assert response.status_code == 200
+        assert response.json() == responses["clip"]["text"]
 
-    def test_face_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
+    def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
         byte_image = BytesIO()
         pil_image.save(byte_image, format="jpeg")
-        headers = {"Content-Type": "image/jpg"}
         response = deployed_app.post(
-            "http://localhost:3003/facial-recognition/detect-faces",
-            content=byte_image.getvalue(),
-            headers=headers,
+            "http://localhost:3003/predict",
+            data={
+                "modelName": "buffalo_l",
+                "modelType": "facial-recognition",
+                "options": json.dumps({"minScore": 0.034}),
+            },
+            files={"image": byte_image.getvalue()},
        )
         assert response.status_code == 200
+        assert response.json() == responses["facial-recognition"]
 
 
 def test_sess_options() -> None:
machine-learning/responses.json (new file, 1570 lines): file diff suppressed because it is too large.
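The expected outputs that the endpoint tests compare against now live in this new JSON file rather than as inline constants. Judging only from how the tests index the responses fixture, its top-level shape is presumably something like the sketch below; the leaf values are placeholders, not the real contents.

# Presumed structure of responses.json, inferred from the test assertions
# above; the "..." leaves stand in for full model outputs.
{
    "image-classification": [{"label": "...", "score": 0.0}],
    "clip": {
        "image": ["..."],   # image embedding
        "text": ["..."],    # text embedding
    },
    "facial-recognition": ["..."],  # list of detected faces
}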