forked from Cutlery/immich
		
	* added testing * github action for python, made mypy happy * formatted with black * minor fixes and styling * test model cache * cache test dependencies * narrowed model cache tests * moved endpoint tests to their own class * cleaned up fixtures * formatting * removed unused dep
		
			
				
	
	
		
			120 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			120 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from types import SimpleNamespace
 | 
						|
from typing import Any, Iterator, TypeAlias
 | 
						|
from unittest import mock
 | 
						|
 | 
						|
import numpy as np
 | 
						|
import pytest
 | 
						|
from fastapi.testclient import TestClient
 | 
						|
from PIL import Image
 | 
						|
 | 
						|
from .main import app, init_state
 | 
						|
 | 
						|
# Shorthand for the NumPy array type used in the fixture signatures in this module.
# NOTE(review): cv_image is built from a PIL RGB image, which presumably yields
# uint8 data rather than float32 — confirm the dtype parameter is intended.
ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def pil_image() -> Image.Image:
    """Create a new 600x800 RGB PIL image for use as test input."""
    dimensions = (600, 800)
    return Image.new(mode="RGB", size=dimensions)
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def cv_image(pil_image: Image.Image) -> ndarray:
    """Return the ``pil_image`` fixture as an array with its last axis reversed."""
    rgb_pixels = np.asarray(pil_image)
    # PIL uses RGB while cv2 uses BGR, so flip the channel (last) axis.
    return rgb_pixels[..., ::-1]
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def mock_classifier_pipeline() -> Iterator[mock.Mock]:
    """Patch ``app.models.image_classification.pipeline`` with a canned classifier.

    While the fixture is active, calling the patched ``pipeline`` returns
    ``forward``, a stand-in that validates its input and returns fixed
    predictions.

    Yields:
        The mock object that replaces ``pipeline``.
    """
    with mock.patch("app.models.image_classification.pipeline") as model:
        classifier_preds = [
            {"label": "that's an image alright", "score": 0.8},
            {"label": "well it ends with .jpg", "score": 0.1},
            {"label": "idk, im just seeing bytes", "score": 0.05},
            {"label": "not sure", "score": 0.04},
            {"label": "probably a virus", "score": 0.01},
        ]

        def forward(
            inputs: Image.Image | list[Image.Image], **kwargs: Any
        ) -> list[dict[str, Any]] | list[list[dict[str, Any]]]:
            """Return the canned predictions for one image or for each image in a list.

            Raises:
                TypeError: if ``inputs`` is neither an image nor a list of images.
            """
            # Handle the batch case first: previously a valid list of images fell
            # through to the single-image check and always raised TypeError, which
            # made the batch return path unreachable.
            if isinstance(inputs, list):
                if not all(isinstance(img, Image.Image) for img in inputs):
                    raise TypeError("every element of inputs must be a PIL image")
                # One (shared) prediction list per input image.
                return [classifier_preds] * len(inputs)

            if not isinstance(inputs, Image.Image):
                raise TypeError("inputs must be a PIL image or a list of PIL images")

            return classifier_preds

        model.return_value = forward
        yield model
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def mock_st() -> Iterator[mock.Mock]:
    """Patch ``app.models.clip.SentenceTransformer`` with a stub encoder.

    The patched class returns a mock whose ``encode`` yields one fixed random
    512-element float32 vector, stacked once per item for list inputs.

    Yields:
        The mock object that replaces ``SentenceTransformer``.
    """
    with mock.patch("app.models.clip.SentenceTransformer") as model:
        embedding = np.random.rand(512).astype(np.float32)

        def encode(inputs: Image.Image | list[Image.Image], **kwargs: Any) -> ndarray | list[ndarray]:
            """Return the fixed embedding, or a stack of it for a list of inputs."""
            #  mypy complains unless isinstance(inputs, list) is used explicitly
            if isinstance(inputs, list):
                only_images = all(isinstance(item, Image.Image) for item in inputs)
                only_text = all(isinstance(item, str) for item in inputs)
                # A batch must be homogeneous: all images or all strings.
                if not (only_images or only_text):
                    raise TypeError
                return np.stack([embedding for _ in inputs])

            return embedding

        encoder = mock.Mock()
        encoder.encode = encode
        model.return_value = encoder
        yield model
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def mock_faceanalysis() -> Iterator[mock.Mock]:
    """Patch ``app.models.facial_recognition.FaceAnalysis`` with a stub detector.

    The patched class returns a mock whose ``get`` yields two fixed fake faces
    (detection scores 0.67 and 0.4) for any NumPy array input.

    Yields:
        The mock object that replaces ``FaceAnalysis``.
    """
    with mock.patch("app.models.facial_recognition.FaceAnalysis") as model:

        def make_face(score: float) -> SimpleNamespace:
            """Build one fake face with random geometry and the given score."""
            # SimpleNamespace is used so these fields can be accessed through dot notation
            return SimpleNamespace(
                bbox=np.random.rand(4).astype(np.float32),
                kps=np.random.rand(5, 2).astype(np.float32),
                det_score=np.array([score]).astype(np.float32),
                normed_embedding=np.random.rand(512).astype(np.float32),
            )

        face_preds = [make_face(0.67), make_face(0.4)]

        def get(image: np.ndarray[int, np.dtype[np.float32]], **kwargs: Any) -> list[SimpleNamespace]:
            """Return the fake faces; reject anything that is not a NumPy array."""
            if isinstance(image, np.ndarray):
                return face_preds
            raise TypeError

        analyzer = mock.Mock()
        analyzer.get = get
        model.return_value = analyzer
        yield model
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
def mock_get_model() -> Iterator[mock.Mock]:
    """Patch ``InferenceModel.from_model_type`` with an autospecced mock.

    Yields:
        The mock standing in for ``app.models.cache.InferenceModel.from_model_type``.
    """
    target = "app.models.cache.InferenceModel.from_model_type"
    with mock.patch(target, autospec=True) as from_model_type_mock:
        yield from_model_type_mock
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture(scope="session")
def deployed_app() -> TestClient:
    """Initialise application state once per session and return a client for ``app``."""
    # State is initialised before the client is constructed.
    init_state()
    client = TestClient(app)
    return client
 |