introduced estimator-adapter and estimator-adapter-patch
This commit is contained in:
parent
9c9070e8bf
commit
8b7293be09
0
image_prediction/estimator/adapter/__init__.py
Normal file
0
image_prediction/estimator/adapter/__init__.py
Normal file
10
image_prediction/estimator/adapter/adapter.py
Normal file
10
image_prediction/estimator/adapter/adapter.py
Normal file
@ -0,0 +1,10 @@
|
||||
import abc
|
||||
|
||||
|
||||
class EstimatorAdapter(abc.ABC):
|
||||
def __init__(self, estimator):
|
||||
self.estimator = estimator
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, batch):
|
||||
pass
|
||||
20
image_prediction/estimator/adapter/patch.py
Normal file
20
image_prediction/estimator/adapter/patch.py
Normal file
@ -0,0 +1,20 @@
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.estimator.estimator import Estimator
|
||||
|
||||
|
||||
class EstimatorAdapterPatch(EstimatorAdapter):
|
||||
def __init__(self, estimator: Estimator):
|
||||
super().__init__(estimator=estimator)
|
||||
self.__output_batch = None
|
||||
|
||||
@property
|
||||
def output_batch(self):
|
||||
return self.__output_batch
|
||||
|
||||
@output_batch.setter
|
||||
def output_batch(self, output_batch):
|
||||
self.__output_batch = output_batch
|
||||
|
||||
def predict(self, batch):
|
||||
self.estimator.predict(batch)
|
||||
return self.__output_batch
|
||||
10
image_prediction/estimator/estimator.py
Normal file
10
image_prediction/estimator/estimator.py
Normal file
@ -0,0 +1,10 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Estimator(abc.ABC):
|
||||
def __init__(self, estimator):
|
||||
self.estimator = estimator
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, batch):
|
||||
pass
|
||||
0
image_prediction/estimator/estimators/__init__.py
Normal file
0
image_prediction/estimator/estimators/__init__.py
Normal file
10
image_prediction/estimator/estimators/mock.py
Normal file
10
image_prediction/estimator/estimators/mock.py
Normal file
@ -0,0 +1,10 @@
|
||||
from image_prediction.estimator.estimator import Estimator
|
||||
|
||||
|
||||
class EstimatorMock(Estimator):
|
||||
|
||||
def __init__(self, estimator=lambda x: x):
|
||||
super().__init__(estimator=estimator)
|
||||
|
||||
def predict(self, batch):
|
||||
return self.estimator(batch)
|
||||
@ -1,14 +0,0 @@
|
||||
class EstimatorMock:
|
||||
def __init__(self):
|
||||
self.__output_batch = None
|
||||
|
||||
@property
|
||||
def output_batch(self):
|
||||
return self.__output_batch
|
||||
|
||||
@output_batch.setter
|
||||
def output_batch(self, output_batch):
|
||||
self.__output_batch = output_batch
|
||||
|
||||
def predict(self, batch):
|
||||
return self.__output_batch
|
||||
@ -1,6 +0,0 @@
|
||||
from image_prediction.service_estimator.service_estimator import ServiceEstimator
|
||||
|
||||
|
||||
class ServiceEstimatorMock(ServiceEstimator):
|
||||
def __init__(self, estimator, classes):
|
||||
super().__init__(estimator=estimator, classes=classes)
|
||||
@ -1,14 +1,12 @@
|
||||
import abc
|
||||
from typing import Mapping, List
|
||||
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
|
||||
|
||||
class ServiceEstimator(abc.ABC):
|
||||
def __init__(self, estimator, classes):
|
||||
self.__estimator = estimator
|
||||
class ServiceEstimator:
|
||||
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
|
||||
self.__estimator_adapter = estimator_adapter
|
||||
self.__classes = classes
|
||||
|
||||
@property
|
||||
def estimator(self):
|
||||
return self.__estimator
|
||||
|
||||
def predict(self, batch):
|
||||
return [self.__classes[numeric_label] for numeric_label in self.estimator.predict(batch)]
|
||||
def predict(self, batch) -> List[str]:
|
||||
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
||||
|
||||
@ -3,8 +3,9 @@ import logging
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from image_prediction.estimator.mock import EstimatorMock
|
||||
from image_prediction.service_estimator.mock import ServiceEstimatorMock
|
||||
from image_prediction.estimator.estimators.mock import EstimatorMock
|
||||
from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
|
||||
from image_prediction.service_estimator.service_estimator import ServiceEstimator
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@ -12,10 +13,16 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def estimator(output_batch):
|
||||
estimator = EstimatorMock()
|
||||
estimator.output_batch = output_batch
|
||||
return estimator
|
||||
def estimator(estimator_type):
|
||||
if estimator_type == "mock":
|
||||
return EstimatorMock()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def estimator_adapter(output_batch, estimator):
|
||||
estimator_adapter = EstimatorAdapterPatch(estimator)
|
||||
estimator_adapter.output_batch = output_batch
|
||||
return estimator_adapter
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -43,12 +50,11 @@ def map_labels(numeric_labels, classes):
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def service_estimator(model_type, estimator, classes):
|
||||
if model_type == "mock":
|
||||
return ServiceEstimatorMock(estimator, classes)
|
||||
def service_estimator(estimator_adapter, classes):
|
||||
return ServiceEstimator(estimator_adapter, classes)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
|
||||
@pytest.mark.parametrize("estimator_type", ["mock"], scope="session")
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
||||
def test_predict(service_estimator, input_batch, expected_predictions):
|
||||
predictions = service_estimator.predict(input_batch)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user