introduced estimator-adapter and estimator-adapter-patch

This commit is contained in:
Matthias Bisping 2022-03-25 13:35:03 +01:00
parent 9c9070e8bf
commit 8b7293be09
10 changed files with 74 additions and 40 deletions

View File

@ -0,0 +1,10 @@
import abc
class EstimatorAdapter(abc.ABC):
def __init__(self, estimator):
self.estimator = estimator
@abc.abstractmethod
def predict(self, batch):
pass

View File

@ -0,0 +1,20 @@
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.estimator import Estimator
class EstimatorAdapterPatch(EstimatorAdapter):
def __init__(self, estimator: Estimator):
super().__init__(estimator=estimator)
self.__output_batch = None
@property
def output_batch(self):
return self.__output_batch
@output_batch.setter
def output_batch(self, output_batch):
self.__output_batch = output_batch
def predict(self, batch):
self.estimator.predict(batch)
return self.__output_batch

View File

@ -0,0 +1,10 @@
import abc
class Estimator(abc.ABC):
def __init__(self, estimator):
self.estimator = estimator
@abc.abstractmethod
def predict(self, batch):
pass

View File

@ -0,0 +1,10 @@
from image_prediction.estimator.estimator import Estimator
class EstimatorMock(Estimator):
def __init__(self, estimator=lambda x: x):
super().__init__(estimator=estimator)
def predict(self, batch):
return self.estimator(batch)

View File

@ -1,14 +0,0 @@
class EstimatorMock:
def __init__(self):
self.__output_batch = None
@property
def output_batch(self):
return self.__output_batch
@output_batch.setter
def output_batch(self, output_batch):
self.__output_batch = output_batch
def predict(self, batch):
return self.__output_batch

View File

@ -1,6 +0,0 @@
from image_prediction.service_estimator.service_estimator import ServiceEstimator
class ServiceEstimatorMock(ServiceEstimator):
def __init__(self, estimator, classes):
super().__init__(estimator=estimator, classes=classes)

View File

@ -1,14 +1,12 @@
import abc from typing import Mapping, List
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
class ServiceEstimator(abc.ABC): class ServiceEstimator:
def __init__(self, estimator, classes): def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
self.__estimator = estimator self.__estimator_adapter = estimator_adapter
self.__classes = classes self.__classes = classes
@property def predict(self, batch) -> List[str]:
def estimator(self): return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
return self.__estimator
def predict(self, batch):
return [self.__classes[numeric_label] for numeric_label in self.estimator.predict(batch)]

View File

@ -3,8 +3,9 @@ import logging
import numpy as np import numpy as np
import pytest import pytest
from image_prediction.estimator.mock import EstimatorMock from image_prediction.estimator.estimators.mock import EstimatorMock
from image_prediction.service_estimator.mock import ServiceEstimatorMock from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
from image_prediction.service_estimator.service_estimator import ServiceEstimator
from image_prediction.utils import get_logger from image_prediction.utils import get_logger
logger = get_logger() logger = get_logger()
@ -12,10 +13,16 @@ logger.setLevel(logging.DEBUG)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def estimator(output_batch): def estimator(estimator_type):
estimator = EstimatorMock() if estimator_type == "mock":
estimator.output_batch = output_batch return EstimatorMock()
return estimator
@pytest.fixture(scope="session")
def estimator_adapter(output_batch, estimator):
estimator_adapter = EstimatorAdapterPatch(estimator)
estimator_adapter.output_batch = output_batch
return estimator_adapter
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -43,12 +50,11 @@ def map_labels(numeric_labels, classes):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def service_estimator(model_type, estimator, classes): def service_estimator(estimator_adapter, classes):
if model_type == "mock": return ServiceEstimator(estimator_adapter, classes)
return ServiceEstimatorMock(estimator, classes)
@pytest.mark.parametrize("model_type", ["mock"], scope="session") @pytest.mark.parametrize("estimator_type", ["mock"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session") @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, input_batch, expected_predictions): def test_predict(service_estimator, input_batch, expected_predictions):
predictions = service_estimator.predict(input_batch) predictions = service_estimator.predict(input_batch)