introduced estimator-adapter and estimator-adapter-patch

This commit is contained in:
Matthias Bisping 2022-03-25 13:35:03 +01:00
parent 9c9070e8bf
commit 8b7293be09
10 changed files with 74 additions and 40 deletions

View File

@ -0,0 +1,10 @@
import abc
class EstimatorAdapter(abc.ABC):
def __init__(self, estimator):
self.estimator = estimator
@abc.abstractmethod
def predict(self, batch):
pass

View File

@ -0,0 +1,20 @@
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.estimator import Estimator
class EstimatorAdapterPatch(EstimatorAdapter):
def __init__(self, estimator: Estimator):
super().__init__(estimator=estimator)
self.__output_batch = None
@property
def output_batch(self):
return self.__output_batch
@output_batch.setter
def output_batch(self, output_batch):
self.__output_batch = output_batch
def predict(self, batch):
self.estimator.predict(batch)
return self.__output_batch

View File

@ -0,0 +1,10 @@
import abc
class Estimator(abc.ABC):
def __init__(self, estimator):
self.estimator = estimator
@abc.abstractmethod
def predict(self, batch):
pass

View File

@ -0,0 +1,10 @@
from image_prediction.estimator.estimator import Estimator
class EstimatorMock(Estimator):
def __init__(self, estimator=lambda x: x):
super().__init__(estimator=estimator)
def predict(self, batch):
return self.estimator(batch)

View File

@ -1,14 +0,0 @@
class EstimatorMock:
def __init__(self):
self.__output_batch = None
@property
def output_batch(self):
return self.__output_batch
@output_batch.setter
def output_batch(self, output_batch):
self.__output_batch = output_batch
def predict(self, batch):
return self.__output_batch

View File

@ -1,6 +0,0 @@
from image_prediction.service_estimator.service_estimator import ServiceEstimator
class ServiceEstimatorMock(ServiceEstimator):
def __init__(self, estimator, classes):
super().__init__(estimator=estimator, classes=classes)

View File

@ -1,14 +1,12 @@
import abc
from typing import Mapping, List
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
class ServiceEstimator(abc.ABC):
def __init__(self, estimator, classes):
self.__estimator = estimator
class ServiceEstimator:
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
self.__estimator_adapter = estimator_adapter
self.__classes = classes
@property
def estimator(self):
return self.__estimator
def predict(self, batch):
return [self.__classes[numeric_label] for numeric_label in self.estimator.predict(batch)]
def predict(self, batch) -> List[str]:
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]

View File

@ -3,8 +3,9 @@ import logging
import numpy as np
import pytest
from image_prediction.estimator.mock import EstimatorMock
from image_prediction.service_estimator.mock import ServiceEstimatorMock
from image_prediction.estimator.estimators.mock import EstimatorMock
from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
from image_prediction.service_estimator.service_estimator import ServiceEstimator
from image_prediction.utils import get_logger
logger = get_logger()
@ -12,10 +13,16 @@ logger.setLevel(logging.DEBUG)
@pytest.fixture(scope="session")
def estimator(output_batch):
estimator = EstimatorMock()
estimator.output_batch = output_batch
return estimator
def estimator(estimator_type):
if estimator_type == "mock":
return EstimatorMock()
@pytest.fixture(scope="session")
def estimator_adapter(output_batch, estimator):
estimator_adapter = EstimatorAdapterPatch(estimator)
estimator_adapter.output_batch = output_batch
return estimator_adapter
@pytest.fixture(scope="session")
@ -43,12 +50,11 @@ def map_labels(numeric_labels, classes):
@pytest.fixture(scope="session")
def service_estimator(model_type, estimator, classes):
if model_type == "mock":
return ServiceEstimatorMock(estimator, classes)
def service_estimator(estimator_adapter, classes):
return ServiceEstimator(estimator_adapter, classes)
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
@pytest.mark.parametrize("estimator_type", ["mock"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, input_batch, expected_predictions):
predictions = service_estimator.predict(input_batch)