tests for box validation

This commit is contained in:
Matthias Bisping 2022-04-12 16:54:40 +02:00
parent 88a46ae7cd
commit f17a232009
2 changed files with 31 additions and 2 deletions

View File

@ -2,7 +2,7 @@ from itertools import chain
from typing import Iterable
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.image_extractor.extractor import ImageExtractor
from image_prediction.utils.generic import chunk_iterable
@ -24,7 +24,7 @@ class ExtractorClassifier:
responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
return responses
def __call__(self, obj, **kwargs) -> Iterable[ImageMetadataPair]:
def __call__(self, obj, **kwargs) -> Iterable[dict]:
image_metadata_pairs = self.extractor(obj, **kwargs)
batches = chunk_iterable(image_metadata_pairs, chunk_size=16)
predictions = chain.from_iterable(map(self.__process_batch, batches))

View File

@ -0,0 +1,29 @@
import pytest
from image_prediction.exceptions import InvalidBox
from image_prediction.info import Info
from image_prediction.stitching.utils import validate_box_size, validate_box_coords
def test_validate_fail_too_short():
box = {Info.WIDTH: 1, Info.HEIGHT: 0}
with pytest.raises(InvalidBox):
validate_box_size(box)
def test_validate_fail_too_thin():
box = {Info.WIDTH: 0, Info.HEIGHT: 1}
with pytest.raises(InvalidBox):
validate_box_size(box)
def test_validate_fail_xs_width_mismatch():
box = {Info.WIDTH: 2, Info.HEIGHT: 4, Info.X1: 0, Info.Y1: 0, Info.X2: 1, Info.Y2: 4}
with pytest.raises(InvalidBox):
validate_box_coords(box)
def test_validate_fail_ys_width_mismatch():
box = {Info.WIDTH: 2, Info.HEIGHT: 3, Info.X1: 0, Info.Y1: 0, Info.X2: 2, Info.Y2: 4}
with pytest.raises(InvalidBox):
validate_box_coords(box)