introduced estimator-adapter and estimator-adapter-patch
This commit is contained in:
parent
9c9070e8bf
commit
8b7293be09
0
image_prediction/estimator/adapter/__init__.py
Normal file
0
image_prediction/estimator/adapter/__init__.py
Normal file
10
image_prediction/estimator/adapter/adapter.py
Normal file
10
image_prediction/estimator/adapter/adapter.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
|
||||||
|
class EstimatorAdapter(abc.ABC):
|
||||||
|
def __init__(self, estimator):
|
||||||
|
self.estimator = estimator
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def predict(self, batch):
|
||||||
|
pass
|
||||||
20
image_prediction/estimator/adapter/patch.py
Normal file
20
image_prediction/estimator/adapter/patch.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
|
from image_prediction.estimator.estimator import Estimator
|
||||||
|
|
||||||
|
|
||||||
|
class EstimatorAdapterPatch(EstimatorAdapter):
|
||||||
|
def __init__(self, estimator: Estimator):
|
||||||
|
super().__init__(estimator=estimator)
|
||||||
|
self.__output_batch = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_batch(self):
|
||||||
|
return self.__output_batch
|
||||||
|
|
||||||
|
@output_batch.setter
|
||||||
|
def output_batch(self, output_batch):
|
||||||
|
self.__output_batch = output_batch
|
||||||
|
|
||||||
|
def predict(self, batch):
|
||||||
|
self.estimator.predict(batch)
|
||||||
|
return self.__output_batch
|
||||||
10
image_prediction/estimator/estimator.py
Normal file
10
image_prediction/estimator/estimator.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
|
||||||
|
class Estimator(abc.ABC):
|
||||||
|
def __init__(self, estimator):
|
||||||
|
self.estimator = estimator
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def predict(self, batch):
|
||||||
|
pass
|
||||||
0
image_prediction/estimator/estimators/__init__.py
Normal file
0
image_prediction/estimator/estimators/__init__.py
Normal file
10
image_prediction/estimator/estimators/mock.py
Normal file
10
image_prediction/estimator/estimators/mock.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from image_prediction.estimator.estimator import Estimator
|
||||||
|
|
||||||
|
|
||||||
|
class EstimatorMock(Estimator):
|
||||||
|
|
||||||
|
def __init__(self, estimator=lambda x: x):
|
||||||
|
super().__init__(estimator=estimator)
|
||||||
|
|
||||||
|
def predict(self, batch):
|
||||||
|
return self.estimator(batch)
|
||||||
@ -1,14 +0,0 @@
|
|||||||
class EstimatorMock:
|
|
||||||
def __init__(self):
|
|
||||||
self.__output_batch = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_batch(self):
|
|
||||||
return self.__output_batch
|
|
||||||
|
|
||||||
@output_batch.setter
|
|
||||||
def output_batch(self, output_batch):
|
|
||||||
self.__output_batch = output_batch
|
|
||||||
|
|
||||||
def predict(self, batch):
|
|
||||||
return self.__output_batch
|
|
||||||
@ -1,6 +0,0 @@
|
|||||||
from image_prediction.service_estimator.service_estimator import ServiceEstimator
|
|
||||||
|
|
||||||
|
|
||||||
class ServiceEstimatorMock(ServiceEstimator):
|
|
||||||
def __init__(self, estimator, classes):
|
|
||||||
super().__init__(estimator=estimator, classes=classes)
|
|
||||||
@ -1,14 +1,12 @@
|
|||||||
import abc
|
from typing import Mapping, List
|
||||||
|
|
||||||
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
|
|
||||||
|
|
||||||
class ServiceEstimator(abc.ABC):
|
class ServiceEstimator:
|
||||||
def __init__(self, estimator, classes):
|
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
|
||||||
self.__estimator = estimator
|
self.__estimator_adapter = estimator_adapter
|
||||||
self.__classes = classes
|
self.__classes = classes
|
||||||
|
|
||||||
@property
|
def predict(self, batch) -> List[str]:
|
||||||
def estimator(self):
|
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
||||||
return self.__estimator
|
|
||||||
|
|
||||||
def predict(self, batch):
|
|
||||||
return [self.__classes[numeric_label] for numeric_label in self.estimator.predict(batch)]
|
|
||||||
|
|||||||
@ -3,8 +3,9 @@ import logging
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.estimator.mock import EstimatorMock
|
from image_prediction.estimator.estimators.mock import EstimatorMock
|
||||||
from image_prediction.service_estimator.mock import ServiceEstimatorMock
|
from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
|
||||||
|
from image_prediction.service_estimator.service_estimator import ServiceEstimator
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
@ -12,10 +13,16 @@ logger.setLevel(logging.DEBUG)
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def estimator(output_batch):
|
def estimator(estimator_type):
|
||||||
estimator = EstimatorMock()
|
if estimator_type == "mock":
|
||||||
estimator.output_batch = output_batch
|
return EstimatorMock()
|
||||||
return estimator
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def estimator_adapter(output_batch, estimator):
|
||||||
|
estimator_adapter = EstimatorAdapterPatch(estimator)
|
||||||
|
estimator_adapter.output_batch = output_batch
|
||||||
|
return estimator_adapter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -43,12 +50,11 @@ def map_labels(numeric_labels, classes):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def service_estimator(model_type, estimator, classes):
|
def service_estimator(estimator_adapter, classes):
|
||||||
if model_type == "mock":
|
return ServiceEstimator(estimator_adapter, classes)
|
||||||
return ServiceEstimatorMock(estimator, classes)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
|
@pytest.mark.parametrize("estimator_type", ["mock"], scope="session")
|
||||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
||||||
def test_predict(service_estimator, input_batch, expected_predictions):
|
def test_predict(service_estimator, input_batch, expected_predictions):
|
||||||
predictions = service_estimator.predict(input_batch)
|
predictions = service_estimator.predict(input_batch)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user