diff --git a/machine-learning/immich_ml/main.py b/machine-learning/immich_ml/main.py index e7e3a719bb..4fca7a2e2b 100644 --- a/machine-learning/immich_ml/main.py +++ b/machine-learning/immich_ml/main.py @@ -183,7 +183,10 @@ async def predict( text: str | None = Form(default=None), ) -> Any: if image is not None: - inputs: Image | str = await run(lambda: decode_pil(image)) + decoded = await run(lambda: decode_pil(image)) + if decoded.width == 0 or decoded.height == 0: + raise HTTPException(400, "Image has zero width or height") + inputs: Image | str = decoded elif text is not None: inputs = text else: diff --git a/machine-learning/test_main.py b/machine-learning/test_main.py index 0182c57c67..cce334e40e 100644 --- a/machine-learning/test_main.py +++ b/machine-learning/test_main.py @@ -1198,6 +1198,19 @@ class TestLoad: mock_model.model_format = ModelFormat.ONNX +@pytest.mark.parametrize("size", [(0, 100), (100, 0), (0, 0)]) +def test_predict_rejects_empty_image(size: tuple[int, int], deployed_app: TestClient) -> None: + with mock.patch("immich_ml.main.decode_pil", return_value=Image.new("RGB", size)): + response = deployed_app.post( + "http://localhost:3003/predict", + data={"entries": json.dumps({"clip": {"visual": {"modelName": "ViT-B-32__openai"}}})}, + files={"image": b"fake image bytes"}, + ) + + assert response.status_code == 400 + assert "zero" in response.json()["detail"].lower() + + def test_root_endpoint(deployed_app: TestClient) -> None: response = deployed_app.get("http://localhost:3003")