diff --git a/image_prediction/estimator/__init__.py b/image_prediction/estimator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/estimator/mock.py b/image_prediction/estimator/mock.py new file mode 100644 index 0000000..0522a6e --- /dev/null +++ b/image_prediction/estimator/mock.py @@ -0,0 +1,14 @@ +class EstimatorMock: + def __init__(self): + self.__output_batch = None + + @property + def output_batch(self): + return self.__output_batch + + @output_batch.setter + def output_batch(self, output_batch): + self.__output_batch = output_batch + + def predict(self, batch): + return self.__output_batch diff --git a/image_prediction/model/__init__.py b/image_prediction/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/model/mock.py b/image_prediction/model/mock.py new file mode 100644 index 0000000..f52722c --- /dev/null +++ b/image_prediction/model/mock.py @@ -0,0 +1,6 @@ +from image_prediction.model.model import Model + + +class ModelMock(Model): + def __init__(self, estimator): + super().__init__(estimator=estimator) diff --git a/image_prediction/model/model.py b/image_prediction/model/model.py new file mode 100644 index 0000000..72119f4 --- /dev/null +++ b/image_prediction/model/model.py @@ -0,0 +1,13 @@ +import abc + + +class Model(abc.ABC): + def __init__(self, estimator): + self.__estimator = estimator + + @property + def estimator(self): + return self.__estimator + + def predict(self, batch): + return self.estimator.predict(batch) diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index 71b37d1..0000000 --- a/test/conftest.py +++ /dev/null @@ -1,70 +0,0 @@ -import os.path - -import pytest - -from image_prediction.predictor import Predictor - - -@pytest.fixture -def predictions(): - return [ - { - "class": "signature", - "probabilities": { - "signature": 1.0, - "logo": 9.150285377746546e-19, - "other": 4.374506412383356e-19, - "formula": 3.582569597002796e-24, - }, - } - ] - - -@pytest.fixture -def metadata(): - return [ - { - "page_height": 612.0, - "page_width": 792.0, - "height": 61.049999999999955, - "width": 139.35000000000002, - "page_idx": 8, - "x1": 63.5, - "x2": 202.85000000000002, - "y1": 472.0, - "y2": 533.05, - } - ] - - -@pytest.fixture -def response(): - return [ - { - "classification": { - "label": "signature", - "probabilities": {"formula": 0.0, "logo": 0.0, "other": 0.0, "signature": 1.0}, - }, - "filters": { - "allPassed": True, - "geometry": { - "imageFormat": {"quotient": 2.282555282555285, "tooTall": False, "tooWide": False}, - "imageSize": {"quotient": 0.13248234868245012, "tooLarge": False, "tooSmall": False}, - }, - "probability": {"unconfident": False}, - }, - "geometry": {"height": 61.049999999999955, "width": 139.35000000000002}, - "position": {"pageNumber": 9, "x1": 63.5, "x2": 202.85000000000002, "y1": 472.0, "y2": 533.05}, - } - ] - - -@pytest.fixture -def predictor(): - return Predictor() - - -@pytest.fixture -def test_pdf(): - with open("./test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf", "rb") as f: - return f.read() diff --git a/test/unit_tests/model_test.py b/test/unit_tests/model_test.py new file mode 100644 index 0000000..2104a97 --- /dev/null +++ b/test/unit_tests/model_test.py @@ -0,0 +1,36 @@ +import numpy as np +import pytest + +from image_prediction.estimator.mock import EstimatorMock +from image_prediction.model.mock import ModelMock + + +@pytest.fixture(scope="session") +def estimator(): + return EstimatorMock() + + +@pytest.fixture(scope="session") +def batches(batch_size): + input_batch = np.random.normal(size=(batch_size, 10, 15)) + output_batch = np.random.randint(low=42, high=43, size=(batch_size, 10, 15)) + return input_batch, output_batch + + +@pytest.fixture(scope="session") +def classes(): + return ["A", "B", "C"] + + +@pytest.fixture(scope="session") +def model(model_type, estimator): + if model_type == "mock": + return ModelMock(estimator) + + +@pytest.mark.parametrize("model_type", ["mock"], scope="session") +@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session") +def test_predict(model, batches): + input_batch, output_batch = batches + model.estimator.output_batch = output_batch + assert np.all(np.equal(model.predict(input_batch), output_batch)) diff --git a/test/unit_tests/test_predictor.py b/test/unit_tests/test_predictor.py deleted file mode 100644 index 0da6f91..0000000 --- a/test/unit_tests/test_predictor.py +++ /dev/null @@ -1,26 +0,0 @@ -def test_predict_pdf_works(predictor, test_pdf): - # FIXME ugly test since there are '\n's in the dict with unknown heritage - predictions, metadata = predictor.predict_pdf(test_pdf) - predictions = [p for p in predictions][0] - assert predictions["class"] == "formula" - probabilities = predictions["probabilities"] - # Floating point precision problem for output so test only that keys exist not the values - assert all(key in probabilities for key in ("formula", "other", "signature", "logo")) - metadata = list(metadata) - metadata = dict(**metadata[0]) - metadata.pop("document_filename") # temp filename cannot be tested - assert metadata == { - "px_width": 389.0, - "px_height": 389.0, - "width": 194.49999000000003, - "height": 194.49998999999997, - "x1": 320.861, - "x2": 515.36099, - "y1": 347.699, - "y2": 542.19899, - "page_width": 595.2800000000001, - "page_height": 841.89, - "page_rotation": 0, - "page_idx": 1, - "n_pages": 3, - } diff --git a/test/unit_tests/test_response.py b/test/unit_tests/test_response.py deleted file mode 100644 index 696c92b..0000000 --- a/test/unit_tests/test_response.py +++ /dev/null @@ -1,5 +0,0 @@ -from image_prediction.response import build_response - - -def test_build_response_returns_valid_response(predictions, metadata, response): - assert build_response(predictions, metadata) == response