From 8b7293be090d01846fc9397e45eced35dd6555eb Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 25 Mar 2022 13:35:03 +0100 Subject: [PATCH] introduced estimator-adapter and estimator-adapter-patch --- .../estimator/adapter/__init__.py | 0 image_prediction/estimator/adapter/adapter.py | 10 +++++++ image_prediction/estimator/adapter/patch.py | 20 ++++++++++++++ image_prediction/estimator/estimator.py | 10 +++++++ .../estimator/estimators/__init__.py | 0 image_prediction/estimator/estimators/mock.py | 10 +++++++ image_prediction/estimator/mock.py | 14 ---------- image_prediction/service_estimator/mock.py | 6 ----- .../service_estimator/service_estimator.py | 18 ++++++------- test/unit_tests/service_estimator_test.py | 26 ++++++++++++------- 10 files changed, 74 insertions(+), 40 deletions(-) create mode 100644 image_prediction/estimator/adapter/__init__.py create mode 100644 image_prediction/estimator/adapter/adapter.py create mode 100644 image_prediction/estimator/adapter/patch.py create mode 100644 image_prediction/estimator/estimator.py create mode 100644 image_prediction/estimator/estimators/__init__.py create mode 100644 image_prediction/estimator/estimators/mock.py delete mode 100644 image_prediction/estimator/mock.py delete mode 100644 image_prediction/service_estimator/mock.py diff --git a/image_prediction/estimator/adapter/__init__.py b/image_prediction/estimator/adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/estimator/adapter/adapter.py b/image_prediction/estimator/adapter/adapter.py new file mode 100644 index 0000000..8aef944 --- /dev/null +++ b/image_prediction/estimator/adapter/adapter.py @@ -0,0 +1,10 @@ +import abc + + +class EstimatorAdapter(abc.ABC): + def __init__(self, estimator): + self.estimator = estimator + + @abc.abstractmethod + def predict(self, batch): + pass diff --git a/image_prediction/estimator/adapter/patch.py b/image_prediction/estimator/adapter/patch.py new file mode 100644 index 0000000..6385d00 --- /dev/null +++ b/image_prediction/estimator/adapter/patch.py @@ -0,0 +1,20 @@ +from image_prediction.estimator.adapter.adapter import EstimatorAdapter +from image_prediction.estimator.estimator import Estimator + + +class EstimatorAdapterPatch(EstimatorAdapter): + def __init__(self, estimator: Estimator): + super().__init__(estimator=estimator) + self.__output_batch = None + + @property + def output_batch(self): + return self.__output_batch + + @output_batch.setter + def output_batch(self, output_batch): + self.__output_batch = output_batch + + def predict(self, batch): + self.estimator.predict(batch) + return self.__output_batch diff --git a/image_prediction/estimator/estimator.py b/image_prediction/estimator/estimator.py new file mode 100644 index 0000000..e1b6d96 --- /dev/null +++ b/image_prediction/estimator/estimator.py @@ -0,0 +1,10 @@ +import abc + + +class Estimator(abc.ABC): + def __init__(self, estimator): + self.estimator = estimator + + @abc.abstractmethod + def predict(self, batch): + pass diff --git a/image_prediction/estimator/estimators/__init__.py b/image_prediction/estimator/estimators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/estimator/estimators/mock.py b/image_prediction/estimator/estimators/mock.py new file mode 100644 index 0000000..3076f48 --- /dev/null +++ b/image_prediction/estimator/estimators/mock.py @@ -0,0 +1,10 @@ +from image_prediction.estimator.estimator import Estimator + + +class EstimatorMock(Estimator): + + def __init__(self, estimator=lambda x: x): + super().__init__(estimator=estimator) + + def predict(self, batch): + return self.estimator(batch) diff --git a/image_prediction/estimator/mock.py b/image_prediction/estimator/mock.py deleted file mode 100644 index 0522a6e..0000000 --- a/image_prediction/estimator/mock.py +++ /dev/null @@ -1,14 +0,0 @@ -class EstimatorMock: - def __init__(self): - self.__output_batch = None - - @property - def output_batch(self): - return self.__output_batch - - @output_batch.setter - def output_batch(self, output_batch): - self.__output_batch = output_batch - - def predict(self, batch): - return self.__output_batch diff --git a/image_prediction/service_estimator/mock.py b/image_prediction/service_estimator/mock.py deleted file mode 100644 index 815a2f1..0000000 --- a/image_prediction/service_estimator/mock.py +++ /dev/null @@ -1,6 +0,0 @@ -from image_prediction.service_estimator.service_estimator import ServiceEstimator - - -class ServiceEstimatorMock(ServiceEstimator): - def __init__(self, estimator, classes): - super().__init__(estimator=estimator, classes=classes) diff --git a/image_prediction/service_estimator/service_estimator.py b/image_prediction/service_estimator/service_estimator.py index 1c2b2fe..f779b85 100644 --- a/image_prediction/service_estimator/service_estimator.py +++ b/image_prediction/service_estimator/service_estimator.py @@ -1,14 +1,12 @@ -import abc +from typing import Mapping, List + +from image_prediction.estimator.adapter.adapter import EstimatorAdapter -class ServiceEstimator(abc.ABC): - def __init__(self, estimator, classes): - self.__estimator = estimator +class ServiceEstimator: + def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]): + self.__estimator_adapter = estimator_adapter self.__classes = classes - @property - def estimator(self): - return self.__estimator - - def predict(self, batch): - return [self.__classes[numeric_label] for numeric_label in self.estimator.predict(batch)] + def predict(self, batch) -> List[str]: + return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)] diff --git a/test/unit_tests/service_estimator_test.py b/test/unit_tests/service_estimator_test.py index 3609505..ed309ed 100644 --- a/test/unit_tests/service_estimator_test.py +++ b/test/unit_tests/service_estimator_test.py @@ -3,8 +3,9 @@ import logging import numpy as np import pytest -from image_prediction.estimator.mock import EstimatorMock -from image_prediction.service_estimator.mock import ServiceEstimatorMock +from image_prediction.estimator.estimators.mock import EstimatorMock +from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch +from image_prediction.service_estimator.service_estimator import ServiceEstimator from image_prediction.utils import get_logger logger = get_logger() @@ -12,10 +13,16 @@ logger.setLevel(logging.DEBUG) @pytest.fixture(scope="session") -def estimator(output_batch): - estimator = EstimatorMock() - estimator.output_batch = output_batch - return estimator +def estimator(estimator_type): + if estimator_type == "mock": + return EstimatorMock() + + +@pytest.fixture(scope="session") +def estimator_adapter(output_batch, estimator): + estimator_adapter = EstimatorAdapterPatch(estimator) + estimator_adapter.output_batch = output_batch + return estimator_adapter @pytest.fixture(scope="session") @@ -43,12 +50,11 @@ def map_labels(numeric_labels, classes): @pytest.fixture(scope="session") -def service_estimator(model_type, estimator, classes): - if model_type == "mock": - return ServiceEstimatorMock(estimator, classes) +def service_estimator(estimator_adapter, classes): + return ServiceEstimator(estimator_adapter, classes) -@pytest.mark.parametrize("model_type", ["mock"], scope="session") +@pytest.mark.parametrize("estimator_type", ["mock"], scope="session") @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session") def test_predict(service_estimator, input_batch, expected_predictions): predictions = service_estimator.predict(input_batch)