diff --git a/packages/mlopspython-inference/mlopspython_inference/inference_pillow.py b/packages/mlopspython-inference/mlopspython_inference/inference_pillow.py index 63e5a9d3..8031e5ae 100644 --- a/packages/mlopspython-inference/mlopspython_inference/inference_pillow.py +++ b/packages/mlopspython-inference/mlopspython_inference/inference_pillow.py @@ -1,3 +1,4 @@ +import logging from io import BytesIO import numpy as np @@ -23,10 +24,25 @@ def load_image(filename: str|BytesIO): BASE_PATH = Path(__file__).resolve().parent +class IModel(): + def predict(self, img) -> np.ndarray: + pass + +class ModelPillow(IModel): + def __init__(self, model_path: str): + self.model = load_model(model_path) + + def predict(self, img) -> np.ndarray: + return self.model.predict(img) + +class ModelMock(IModel): + def predict(self, img) -> np.ndarray: + return np.array([[1, 0, 0]]) + class Inference: - def __init__(self, logging, model_path: str): + def __init__(self, logging, model: IModel): self.logger = logging.getLogger(__name__) - self.model = load_model(model_path) + self.model = model def execute(self, filepath:str|BytesIO): img = load_image(filepath) diff --git a/packages/mlopspython-inference/tests/input/model/final_model.h5 b/packages/mlopspython-inference/tests/input/model/final_model.h5 deleted file mode 100644 index 60daa135..00000000 Binary files a/packages/mlopspython-inference/tests/input/model/final_model.h5 and /dev/null differ diff --git a/packages/mlopspython-inference/tests/test_inference.py b/packages/mlopspython-inference/tests/test_inference.py index 27cc7656..d9c98a82 100644 --- a/packages/mlopspython-inference/tests/test_inference.py +++ b/packages/mlopspython-inference/tests/test_inference.py @@ -1,22 +1,25 @@ import logging from pathlib import Path +from unittest.mock import MagicMock + import pytest -from mlopspython_inference.inference_pillow import Inference +from mlopspython_inference.inference_pillow import Inference, ModelMock, IModel BASE_PATH = Path(__file__).resolve().parent input_directory = BASE_PATH / "input" -@pytest.mark.skip(reason="Modèle lourd / GPU non requis sur CI. Enlever ce skip si nécessaire.") +#@pytest.mark.skip(reason="Modèle lourd / GPU non requis sur CI. Enlever ce skip si nécessaire.") def test_inference_runs_with_sample_model_and_image(): - model_path = input_directory / "model" / "final_model.h5" image_path = input_directory / "images" / "cat.png" - assert model_path.is_file(), "Modèle de test manquant" assert image_path.is_file(), "Image de test manquante" - inference = Inference(logging, str(model_path)) + model_mock = MagicMock(IModel) + model_mock.execute = MagicMock(return_value=[[1, 0, 0]]) + + inference = Inference(logging, model_mock) result = inference.execute(str(image_path)) assert result["prediction"] in {"Cat", "Dog", "Other"}