fix(ml): load models in separate threads (#4034)

* load models in thread

* set clip mode logs to debug level

* updated tests

* made fixtures slightly less ugly

* moved responses to json file

* formatting
Mert 2023-09-09 05:02:44 -04:00 committed by GitHub
parent f1db257628
commit 258b98c262
9 changed files with 1683 additions and 114 deletions
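
The core of the change: model loading now happens off the event loop, dispatched to the request thread pool and serialized per model type with a lock so that concurrent requests do not load the same model twice. A minimal sketch of that pattern (simplified, with hypothetical names; the actual implementation is in the diff below):

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

# One lock per model type: different types can load in parallel,
# while two requests for the same type serialize on the lock.
locks = {"clip": threading.Lock(), "facial-recognition": threading.Lock()}
pool = ThreadPoolExecutor(max_workers=4)


async def load(model) -> None:
    if model.loaded:  # another request already loaded it
        return

    def _load() -> None:
        with locks[model.model_type]:
            model.load()  # load() is idempotent, so a second caller is a no-op

    # blocking ONNX session creation runs in the pool, not on the event loop
    await asyncio.get_running_loop().run_in_executor(pool, _load)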

View File

@@ -1,4 +1,5 @@
-from typing import Iterator, TypeAlias
+import json
+from typing import Any, Iterator, TypeAlias
 from unittest import mock
 
 import numpy as np
@@ -31,3 +32,8 @@ def mock_get_model() -> Iterator[mock.Mock]:
 def deployed_app() -> TestClient:
     init_state()
     return TestClient(app)
+
+
+@pytest.fixture(scope="session")
+def responses() -> dict[str, Any]:
+    return json.load(open("responses.json", "r"))

View File

@@ -1,10 +1,13 @@
 import asyncio
+import threading
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
+from zipfile import BadZipFile
 
 import orjson
 from fastapi import FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
+from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
 from starlette.formparsers import MultiPartParser
 
 from app.models.base import InferenceModel
@@ -31,6 +34,7 @@ def init_state() -> None:
     )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
+    app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
     log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@@ -63,14 +67,49 @@ async def predict(
         inputs = text
     else:
         raise HTTPException(400, "Either image or text must be provided")
 
-    model: InferenceModel = await app.state.model_cache.get(model_name, model_type, **orjson.loads(options))
+    try:
+        kwargs = orjson.loads(options)
+    except orjson.JSONDecodeError:
+        raise HTTPException(400, f"Invalid options JSON: {options}")
+
+    model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
+    model.configure(**kwargs)
     outputs = await run(model, inputs)
     return ORJSONResponse(outputs)
 
 
 async def run(model: InferenceModel, inputs: Any) -> Any:
-    if app.state.thread_pool is not None:
-        return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
-    else:
+    if app.state.thread_pool is None:
         return model.predict(inputs)
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+
+
+async def load(model: InferenceModel) -> InferenceModel:
+    if model.loaded:
+        return model
+
+    def _load() -> None:
+        with app.state.locks[model.model_type]:
+            model.load()
+
+    loop = asyncio.get_running_loop()
+    try:
+        if app.state.thread_pool is None:
+            model.load()
+        else:
+            await loop.run_in_executor(app.state.thread_pool, _load)
+        return model
+    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
+        log.warn(
+            (
+                f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
+                "Clearing cache and retrying."
+            )
+        )
+        model.clear_cache()
+        if app.state.thread_pool is None:
+            model.load()
+        else:
+            await loop.run_in_executor(app.state.thread_pool, _load)
+        return model
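
For context, the updated tests further down target the consolidated /predict route, which takes the model name, model type, and JSON-encoded options as multipart form fields, plus an image file or a text field. A hypothetical client call (the URL, file path, and option values here are illustrative):

import json

import requests  # any HTTP client works; requests is used here for illustration

with open("image.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:3003/predict",
        data={
            "modelName": "ViT-B-32::openai",
            "modelType": "clip",
            "options": json.dumps({"mode": "vision"}),
        },
        files={"image": f},  # send a "text" form field instead for CLIP text mode
    )
response.raise_for_status()
print(response.json())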

View File

@@ -5,10 +5,8 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
 from typing import Any
-from zipfile import BadZipFile
 
 import onnxruntime as ort
-from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
 
 from ..config import get_cache_dir, log, settings
 from ..schemas import ModelType
@@ -21,16 +19,13 @@ class InferenceModel(ABC):
         self,
         model_name: str,
         cache_dir: Path | str | None = None,
-        eager: bool = True,
         inter_op_num_threads: int = settings.model_inter_op_threads,
         intra_op_num_threads: int = settings.model_intra_op_threads,
         **model_kwargs: Any,
     ) -> None:
         self.model_name = model_name
-        self._loaded = False
+        self.loaded = False
         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
-        loader = self.load if eager else self.download
-
         self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
         # don't pre-allocate more memory than needed
         self.provider_options = model_kwargs.pop(
@@ -55,34 +50,23 @@
         self.sess_options.intra_op_num_threads = intra_op_num_threads
         self.sess_options.enable_cpu_mem_arena = False
 
-        try:
-            loader(**model_kwargs)
-        except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
-            log.warn(
-                (
-                    f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
-                    "Clearing cache and retrying."
-                )
-            )
-            self.clear_cache()
-            loader(**model_kwargs)
-
-    def download(self, **model_kwargs: Any) -> None:
+    def download(self) -> None:
         if not self.cached:
             log.info(
-                (f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
+                (f"Downloading {self.model_type.replace('-', ' ')} model '{self.model_name}'." "This may take a while.")
             )
-            self._download(**model_kwargs)
+            self._download()
 
-    def load(self, **model_kwargs: Any) -> None:
-        self.download(**model_kwargs)
-        self._load(**model_kwargs)
-        self._loaded = True
+    def load(self) -> None:
+        if self.loaded:
+            return
+        self.download()
+        log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}'")
+        self._load()
+        self.loaded = True
 
     def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
-        if not self._loaded:
-            log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
-            self.load()
+        self.load()
         if model_kwargs:
             self.configure(**model_kwargs)
         return self._predict(inputs)
@@ -95,11 +79,11 @@
         pass
 
     @abstractmethod
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         ...
 
     @abstractmethod
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         ...
 
     @property
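
In short, the base-class contract after this refactor: _download and _load no longer take kwargs, the loaded flag is public, load() is idempotent, and predict() loads lazily. A stripped-down standalone illustration of that lifecycle (a toy class for exposition, not the real InferenceModel):

from abc import ABC, abstractmethod
from typing import Any


class LazyModel(ABC):
    """Toy sketch of the lazy-load lifecycle used by InferenceModel."""

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self.loaded = False

    def load(self) -> None:
        if self.loaded:  # idempotent: safe to call on every request
            return
        self._download()
        self._load()
        self.loaded = True

    def predict(self, inputs: Any) -> Any:
        self.load()  # the first prediction triggers download + load
        return self._predict(inputs)

    @abstractmethod
    def _download(self) -> None:
        ...

    @abstractmethod
    def _load(self) -> None:
        ...

    @abstractmethod
    def _predict(self, inputs: Any) -> Any:
        ...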

View File

@@ -17,7 +17,7 @@ class ModelCache:
         revalidate: bool = False,
         timeout: int | None = None,
         profiling: bool = False,
-    ):
+    ) -> None:
         """
         Args:
             ttl: Unloads model after this duration. Disabled if None. Defaults to None.

View File

@@ -42,7 +42,7 @@ class CLIPEncoder(InferenceModel):
         jina_model_name = self._get_jina_model_name(model_name)
         super().__init__(jina_model_name, cache_dir, **model_kwargs)
 
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
         text_onnx_path = self.cache_dir / "textual.onnx"
         vision_onnx_path = self.cache_dir / "visual.onnx"
@@ -53,8 +53,9 @@ class CLIPEncoder(InferenceModel):
         if not vision_onnx_path.is_file():
             self._download_model(*models[1])
 
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         if self.mode == "text" or self.mode is None:
+            log.debug(f"Loading clip text model '{self.model_name}'")
             self.text_model = ort.InferenceSession(
                 self.cache_dir / "textual.onnx",
                 sess_options=self.sess_options,
@@ -65,6 +66,7 @@ class CLIPEncoder(InferenceModel):
             self.tokenizer = Tokenizer(self.model_name)
 
         if self.mode == "vision" or self.mode is None:
+            log.debug(f"Loading clip vision model '{self.model_name}'")
             self.vision_model = ort.InferenceSession(
                 self.cache_dir / "visual.onnx",
                 sess_options=self.sess_options,

View File

@@ -26,7 +26,7 @@ class FaceRecognizer(InferenceModel):
         self.min_score = model_kwargs.pop("minScore", min_score)
         super().__init__(model_name, cache_dir, **model_kwargs)
 
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         zip_file = self.cache_dir / f"{self.model_name}.zip"
         download_file(f"{BASE_REPO_URL}/{self.model_name}.zip", zip_file)
         with zipfile.ZipFile(zip_file, "r") as zip:
@@ -36,7 +36,7 @@ class FaceRecognizer(InferenceModel):
             zip.extractall(self.cache_dir, members=[det_file, rec_file])
         zip_file.unlink()
 
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         try:
             det_file = next(self.cache_dir.glob("det_*.onnx"))
             rec_file = next(self.cache_dir.glob("w600k_*.onnx"))

View File

@@ -26,7 +26,7 @@ class ImageClassifier(InferenceModel):
         self.min_score = model_kwargs.pop("minScore", min_score)
         super().__init__(model_name, cache_dir, **model_kwargs)
 
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         snapshot_download(
             cache_dir=self.cache_dir,
             repo_id=self.model_name,
@@ -35,10 +35,10 @@ class ImageClassifier(InferenceModel):
             local_dir_use_symlinks=True,
         )
 
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
         model_path = self.cache_dir / "model.onnx"
-        model_kwargs |= {
+        model_kwargs = {
             "cache_dir": self.cache_dir,
             "provider": self.providers[0],
             "provider_options": self.provider_options[0],

View File

@ -1,11 +1,11 @@
import json
import pickle import pickle
from io import BytesIO from io import BytesIO
from typing import TypeAlias from typing import Any, TypeAlias
from unittest import mock from unittest import mock
import cv2 import cv2
import numpy as np import numpy as np
import onnxruntime as ort
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from PIL import Image from PIL import Image
@@ -31,23 +31,6 @@ class TestImageClassifier:
         {"label": "probably a virus", "score": 0.01},
     ]
 
-    def test_eager_init(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(ImageClassifier, "download")
-        mock_load = mocker.patch.object(ImageClassifier, "load")
-        classifier = ImageClassifier("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")
-
-        assert classifier.model_name == "test_model_name"
-        mock_load.assert_called_once_with(test_arg="test_arg")
-
-    def test_lazy_init(self, mocker: MockerFixture) -> None:
-        mock_download = mocker.patch.object(ImageClassifier, "download")
-        mock_load = mocker.patch.object(ImageClassifier, "load")
-        face_model = ImageClassifier("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")
-
-        assert face_model.model_name == "test_model_name"
-        mock_download.assert_called_once_with(test_arg="test_arg")
-        mock_load.assert_not_called()
-
     def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(ImageClassifier, "load")
         classifier = ImageClassifier("test_model_name", min_score=0.0)
@@ -74,23 +57,6 @@
 class TestCLIP:
     embedding = np.random.rand(512).astype(np.float32)
 
-    def test_eager_init(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(CLIPEncoder, "download")
-        mock_load = mocker.patch.object(CLIPEncoder, "load")
-        clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=True, test_arg="test_arg")
-
-        assert clip_model.model_name == "ViT-B-32::openai"
-        mock_load.assert_called_once_with(test_arg="test_arg")
-
-    def test_lazy_init(self, mocker: MockerFixture) -> None:
-        mock_download = mocker.patch.object(CLIPEncoder, "download")
-        mock_load = mocker.patch.object(CLIPEncoder, "load")
-        clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=False, test_arg="test_arg")
-
-        assert clip_model.model_name == "ViT-B-32::openai"
-        mock_download.assert_called_once_with(test_arg="test_arg")
-        mock_load.assert_not_called()
-
     def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
@@ -119,23 +85,6 @@
 
 
 class TestFaceRecognition:
-    def test_eager_init(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(FaceRecognizer, "download")
-        mock_load = mocker.patch.object(FaceRecognizer, "load")
-        face_model = FaceRecognizer("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")
-
-        assert face_model.model_name == "test_model_name"
-        mock_load.assert_called_once_with(test_arg="test_arg")
-
-    def test_lazy_init(self, mocker: MockerFixture) -> None:
-        mock_download = mocker.patch.object(FaceRecognizer, "download")
-        mock_load = mocker.patch.object(FaceRecognizer, "load")
-        face_model = FaceRecognizer("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")
-
-        assert face_model.model_name == "test_model_name"
-        mock_download.assert_called_once_with(test_arg="test_arg")
-        mock_load.assert_not_called()
-
     def test_set_min_score(self, mocker: MockerFixture) -> None:
         mocker.patch.object(FaceRecognizer, "load")
         face_recognizer = FaceRecognizer("test_model_name", cache_dir="test_cache", min_score=0.5)
@@ -220,45 +169,64 @@ class TestCache:
     reason="More time-consuming since it deploys the app and loads models.",
 )
 class TestEndpoints:
-    def test_tagging_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
+    def test_tagging_endpoint(
+        self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient
+    ) -> None:
         byte_image = BytesIO()
         pil_image.save(byte_image, format="jpeg")
-        headers = {"Content-Type": "image/jpg"}
         response = deployed_app.post(
-            "http://localhost:3003/image-classifier/tag-image",
-            content=byte_image.getvalue(),
-            headers=headers,
+            "http://localhost:3003/predict",
+            data={
+                "modelName": "microsoft/resnet-50",
+                "modelType": "image-classification",
+                "options": json.dumps({"minScore": 0.0}),
+            },
+            files={"image": byte_image.getvalue()},
         )
         assert response.status_code == 200
+        assert response.json() == responses["image-classification"]
 
-    def test_clip_image_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
+    def test_clip_image_endpoint(
+        self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient
+    ) -> None:
         byte_image = BytesIO()
         pil_image.save(byte_image, format="jpeg")
-        headers = {"Content-Type": "image/jpg"}
         response = deployed_app.post(
-            "http://localhost:3003/sentence-transformer/encode-image",
-            content=byte_image.getvalue(),
-            headers=headers,
+            "http://localhost:3003/predict",
+            data={"modelName": "ViT-B-32::openai", "modelType": "clip", "options": json.dumps({"mode": "vision"})},
+            files={"image": byte_image.getvalue()},
         )
         assert response.status_code == 200
+        assert response.json() == responses["clip"]["image"]
 
-    def test_clip_text_endpoint(self, deployed_app: TestClient) -> None:
+    def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
         response = deployed_app.post(
-            "http://localhost:3003/sentence-transformer/encode-text",
-            json={"text": "test search query"},
+            "http://localhost:3003/predict",
+            data={
+                "modelName": "ViT-B-32::openai",
+                "modelType": "clip",
+                "text": "test search query",
+                "options": json.dumps({"mode": "text"}),
+            },
         )
         assert response.status_code == 200
+        assert response.json() == responses["clip"]["text"]
 
-    def test_face_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
+    def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
         byte_image = BytesIO()
         pil_image.save(byte_image, format="jpeg")
-        headers = {"Content-Type": "image/jpg"}
         response = deployed_app.post(
-            "http://localhost:3003/facial-recognition/detect-faces",
-            content=byte_image.getvalue(),
-            headers=headers,
+            "http://localhost:3003/predict",
+            data={
+                "modelName": "buffalo_l",
+                "modelType": "facial-recognition",
+                "options": json.dumps({"minScore": 0.034}),
+            },
+            files={"image": byte_image.getvalue()},
         )
         assert response.status_code == 200
+        assert response.json() == responses["facial-recognition"]
 
 
 def test_sess_options() -> None:

File diff suppressed because it is too large