refactoring
This commit is contained in:
parent
0dcd389415
commit
7428aeee37
@ -28,7 +28,7 @@ class Classifier:
|
|||||||
if isinstance(batch, np.ndarray) and batch.shape[0] == 0:
|
if isinstance(batch, np.ndarray) and batch.shape[0] == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return list(self.__pipe(batch)) # TODO: list?
|
return self.__pipe(batch)
|
||||||
|
|
||||||
def __call__(self, batch: np.array) -> List[str]:
|
def __call__(self, batch: np.array) -> List[str]:
|
||||||
logger.debug("Classifier.predict")
|
logger.debug("Classifier.predict")
|
||||||
|
|||||||
@ -55,4 +55,4 @@ class Pipeline:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, pdf: bytes, page_range: range = None):
|
def __call__(self, pdf: bytes, page_range: range = None):
|
||||||
yield from tqdm(self.pipe(pdf, page_range=page_range))
|
yield from tqdm(self.pipe(pdf, page_range=page_range), desc="Processing images from document", unit=" images")
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import pytest
|
|||||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"])
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"])
|
||||||
@pytest.mark.parametrize("label_format", ["index", "probability"])
|
@pytest.mark.parametrize("label_format", ["index", "probability"])
|
||||||
def test_classifier(classifier, input_batch, expected_predictions_mapped):
|
def test_classifier(classifier, input_batch, expected_predictions_mapped):
|
||||||
predictions = classifier(input_batch)
|
predictions = list(classifier(input_batch))
|
||||||
assert predictions == expected_predictions_mapped
|
assert predictions == expected_predictions_mapped
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user