refactoring

This commit is contained in:
Matthias Bisping 2022-04-25 09:45:45 +02:00
parent 0dcd389415
commit 7428aeee37
3 changed files with 3 additions and 3 deletions

View File

@ -28,7 +28,7 @@ class Classifier:
if isinstance(batch, np.ndarray) and batch.shape[0] == 0: if isinstance(batch, np.ndarray) and batch.shape[0] == 0:
return [] return []
return list(self.__pipe(batch)) # TODO: list? return self.__pipe(batch)
def __call__(self, batch: np.array) -> List[str]: def __call__(self, batch: np.array) -> List[str]:
logger.debug("Classifier.predict") logger.debug("Classifier.predict")

View File

@ -55,4 +55,4 @@ class Pipeline:
) )
def __call__(self, pdf: bytes, page_range: range = None): def __call__(self, pdf: bytes, page_range: range = None):
yield from tqdm(self.pipe(pdf, page_range=page_range)) yield from tqdm(self.pipe(pdf, page_range=page_range), desc="Processing images from document", unit=" images")

View File

@ -4,7 +4,7 @@ import pytest
@pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"]) @pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"])
@pytest.mark.parametrize("label_format", ["index", "probability"]) @pytest.mark.parametrize("label_format", ["index", "probability"])
def test_classifier(classifier, input_batch, expected_predictions_mapped): def test_classifier(classifier, input_batch, expected_predictions_mapped):
predictions = classifier(input_batch) predictions = list(classifier(input_batch))
assert predictions == expected_predictions_mapped assert predictions == expected_predictions_mapped