diff --git a/image_prediction/classifier/image_classifier.py b/image_prediction/classifier/image_classifier.py index 7bd21fc..0963e81 100644 --- a/image_prediction/classifier/image_classifier.py +++ b/image_prediction/classifier/image_classifier.py @@ -22,7 +22,8 @@ class ImageClassifier: def predict(self, images: Iterable[Image], batch_size=16): batches = chunk_iterable(images, chunk_size=batch_size) - return chain(*map(self.pipe, batches)) + predictions = chain.from_iterable(map(self.pipe, batches)) + return predictions def __call__(self, images: Iterable[Image], batch_size=16): - return self.predict(images, batch_size=batch_size) + yield from self.predict(images, batch_size=batch_size) diff --git a/image_prediction/extractor_classifier/extractor_classifier.py b/image_prediction/extractor_classifier/extractor_classifier.py index 43e7788..930dfb1 100644 --- a/image_prediction/extractor_classifier/extractor_classifier.py +++ b/image_prediction/extractor_classifier/extractor_classifier.py @@ -27,4 +27,5 @@ class ExtractorClassifier: def __call__(self, obj) -> Iterable[ImageMetadataPair]: image_metadata_pairs = self.extractor(obj) batches = chunk_iterable(image_metadata_pairs, chunk_size=16) - return chain(*map(self.__process_batch, batches)) + predictions = chain.from_iterable(map(self.__process_batch, batches)) + return predictions diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 45e6ae7..e7eab63 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -9,6 +9,9 @@ from tqdm import tqdm from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.info import Info +from image_prediction.utils import get_logger + +logger = get_logger() class ParsablePDFImageExtractor(ImageExtractor): @@ -47,8 +50,12 @@ class ParsablePDFImageExtractor(ImageExtractor): return starmap(ImageMetadataPair, zip(images, metadata)) def extract(self, pdf: bytes): + logger.debug("Extracting") + self.doc = fitz.Document(stream=pdf) - image_metadata_pairs = chain( - *map(self.__process_images_on_page, tqdm(self.doc, desc="Extracting", disable=not self.verbose)) + + image_metadata_pairs = chain.from_iterable( + map(self.__process_images_on_page, tqdm(self.doc, desc="Extracting", disable=not self.verbose)) ) + return image_metadata_pairs diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index ff9677e..af4e416 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -51,4 +51,4 @@ class Pipeline: self.pipe = rcompose(get_extractor_classifier(), get_formatter()) def __call__(self, pdf: bytes): - return self.pipe(pdf) + yield from self.pipe(pdf)