diff --git a/image_prediction/estimator/estimator.py b/image_prediction/estimator/estimator.py index 66e6ed1..691743a 100644 --- a/image_prediction/estimator/estimator.py +++ b/image_prediction/estimator/estimator.py @@ -16,7 +16,9 @@ class Estimator: def predict(self, batch: np.array) -> List[str]: if batch.shape[0] == 0: - logger.warning("Estimator received empty batch.") return [] return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)] + + def __call__(self, batch: np.array) -> List[str]: + return self.predict(batch) diff --git a/image_prediction/estimator/preprocessor/preprocessor.py b/image_prediction/estimator/preprocessor/preprocessor.py index c7c4ce9..aab8f12 100644 --- a/image_prediction/estimator/preprocessor/preprocessor.py +++ b/image_prediction/estimator/preprocessor/preprocessor.py @@ -1,9 +1,12 @@ from image_prediction.estimator.adapter.adapter import EstimatorAdapter -class EstimatorPreprocessor: - def __init__(self, estimator: EstimatorAdapter): - self.estimator = estimator +class Preprocessor: + def __init__(self): + pass - def predict(self, batch): - return self.estimator.predict(batch) + # def preprocess(self, batch): + # return (map(image_to_normalized_tensor, batch)) + + def __call__(self, batch): + return self.preprocess(batch) diff --git a/image_prediction/estimator/preprocessor/utils.py b/image_prediction/estimator/preprocessor/utils.py new file mode 100644 index 0000000..dbab144 --- /dev/null +++ b/image_prediction/estimator/preprocessor/utils.py @@ -0,0 +1,10 @@ +import numpy as np +from PIL.Image import Image + + +def image_to_normalized_tensor(image: Image) -> np.ndarray: + return np.array(image) / 255 + + +def images_to_batch_tensor(images) -> np.ndarray: + return np.array(list(map(image_to_normalized_tensor, images))) diff --git a/image_prediction/utils.py b/image_prediction/utils.py index 15badca..578fda2 100644 --- a/image_prediction/utils.py +++ b/image_prediction/utils.py @@ -1,8 +1,11 @@ import logging import tempfile from contextlib import contextmanager +from itertools import takewhile, starmap, islice, repeat +from operator import truth from image_prediction.config import CONFIG +from redai.utils import export @contextmanager @@ -66,3 +69,8 @@ def show_banner(): logger.addHandler(handler) logger.info(banner) + + +@export +def chunk_iterable(iterable, chunk_size): + return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size))))) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index d3d247a..460bf83 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -47,17 +47,26 @@ def keras_model(input_size): import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - from tensorflow import keras + import tensorflow as tf - inputs = keras.Input(shape=input_size) - dense = keras.layers.Dense(64, activation="relu") - outputs = keras.layers.Dense(10)(dense(inputs)) - model = keras.Model(inputs=inputs, outputs=outputs) + tf.keras.backend.set_image_data_format('channels_last') + + inputs = tf.keras.Input(shape=input_size) + conv = tf.keras.layers.Conv2D(3, 3) + dense = tf.keras.layers.Dense(10, activation="relu") + + outputs = tf.keras.layers.Dense(10)(dense(conv(inputs))) + model = tf.keras.Model(inputs=inputs, outputs=outputs) model.compile() return model +@pytest.fixture +def batch_size(): + return 4 + + @pytest.fixture def images(input_batch): return list(map(array_to_image, input_batch)) @@ -71,12 +80,12 @@ def input_batch(batch_size, input_size): def array_to_image(array): assert np.all(array <= 1) assert np.all(array >= 0) - return Image.fromarray(np.uint8(array * 255)) + return Image.fromarray(np.uint8(array * 255), mode="RGB") @pytest.fixture -def input_size(width=10, height=15): - return width, height +def input_size(depth=3, width=10, height=15): + return width, height, depth @pytest.fixture diff --git a/test/unit_tests/preprocessor_test.py b/test/unit_tests/preprocessor_test.py new file mode 100644 index 0000000..4102c12 --- /dev/null +++ b/test/unit_tests/preprocessor_test.py @@ -0,0 +1,25 @@ +import logging + +import numpy as np +from PIL import Image + +from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor +from image_prediction.utils import get_logger + + + + +def test_image_to_tensor(images): + + def inner(image): + tensor = image_to_normalized_tensor(image) + image_re = Image.fromarray(np.uint8(tensor * 255), mode="RGB") + return image == image_re and tensor.ndim == 3 + + assert all(map(inner, images)) + + +def test_images_to_batch_tensor(images): + tensor = images_to_batch_tensor(images) + assert isinstance(tensor, np.ndarray) + assert tensor.ndim == 4 diff --git a/test/unit_tests/service_estimator_test.py b/test/unit_tests/service_estimator_test.py index f561901..347e803 100644 --- a/test/unit_tests/service_estimator_test.py +++ b/test/unit_tests/service_estimator_test.py @@ -13,3 +13,15 @@ logger.setLevel(logging.DEBUG) def test_predict(service_estimator, input_batch, expected_predictions): predictions = service_estimator.predict(input_batch) assert predictions == expected_predictions + + +def test_batch_format(input_batch): + + def channels_are_last(input_batch): + return input_batch.shape[-1] == 3 + + def is_fourth_order_tensor(input_batch): + return input_batch.ndim == 4 + + assert channels_are_last(input_batch) + assert is_fourth_order_tensor(input_batch)