diff --git a/image_prediction/extractor_classifier/extractor_classifier.py b/image_prediction/extractor_classifier/extractor_classifier.py index eef74da..9c4c774 100644 --- a/image_prediction/extractor_classifier/extractor_classifier.py +++ b/image_prediction/extractor_classifier/extractor_classifier.py @@ -2,7 +2,7 @@ from itertools import chain from typing import Iterable from image_prediction.classifier.image_classifier import ImageClassifier -from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair +from image_prediction.image_extractor.extractor import ImageExtractor from image_prediction.utils.generic import chunk_iterable @@ -24,7 +24,7 @@ class ExtractorClassifier: responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata)) return responses - def __call__(self, obj, **kwargs) -> Iterable[ImageMetadataPair]: + def __call__(self, obj, **kwargs) -> Iterable[dict]: image_metadata_pairs = self.extractor(obj, **kwargs) batches = chunk_iterable(image_metadata_pairs, chunk_size=16) predictions = chain.from_iterable(map(self.__process_batch, batches)) diff --git a/test/unit_tests/box_validation_test.py b/test/unit_tests/box_validation_test.py new file mode 100644 index 0000000..c467b08 --- /dev/null +++ b/test/unit_tests/box_validation_test.py @@ -0,0 +1,29 @@ +import pytest + +from image_prediction.exceptions import InvalidBox +from image_prediction.info import Info +from image_prediction.stitching.utils import validate_box_size, validate_box_coords + + +def test_validate_fail_too_short(): + box = {Info.WIDTH: 1, Info.HEIGHT: 0} + with pytest.raises(InvalidBox): + validate_box_size(box) + + +def test_validate_fail_too_thin(): + box = {Info.WIDTH: 0, Info.HEIGHT: 1} + with pytest.raises(InvalidBox): + validate_box_size(box) + + +def test_validate_fail_xs_width_mismatch(): + box = {Info.WIDTH: 2, Info.HEIGHT: 4, Info.X1: 0, Info.Y1: 0, Info.X2: 1, Info.Y2: 4} + with pytest.raises(InvalidBox): + validate_box_coords(box) + + +def test_validate_fail_ys_width_mismatch(): + box = {Info.WIDTH: 2, Info.HEIGHT: 3, Info.X1: 0, Info.Y1: 0, Info.X2: 2, Info.Y2: 4} + with pytest.raises(InvalidBox): + validate_box_coords(box)