fixed chaining bug that lead to greedy evaluation
This commit is contained in:
parent
81ab9a5f53
commit
45a07c620a
@ -22,7 +22,8 @@ class ImageClassifier:
|
|||||||
|
|
||||||
def predict(self, images: Iterable[Image], batch_size=16):
|
def predict(self, images: Iterable[Image], batch_size=16):
|
||||||
batches = chunk_iterable(images, chunk_size=batch_size)
|
batches = chunk_iterable(images, chunk_size=batch_size)
|
||||||
return chain(*map(self.pipe, batches))
|
predictions = chain.from_iterable(map(self.pipe, batches))
|
||||||
|
return predictions
|
||||||
|
|
||||||
def __call__(self, images: Iterable[Image], batch_size=16):
|
def __call__(self, images: Iterable[Image], batch_size=16):
|
||||||
return self.predict(images, batch_size=batch_size)
|
yield from self.predict(images, batch_size=batch_size)
|
||||||
|
|||||||
@ -27,4 +27,5 @@ class ExtractorClassifier:
|
|||||||
def __call__(self, obj) -> Iterable[ImageMetadataPair]:
|
def __call__(self, obj) -> Iterable[ImageMetadataPair]:
|
||||||
image_metadata_pairs = self.extractor(obj)
|
image_metadata_pairs = self.extractor(obj)
|
||||||
batches = chunk_iterable(image_metadata_pairs, chunk_size=16)
|
batches = chunk_iterable(image_metadata_pairs, chunk_size=16)
|
||||||
return chain(*map(self.__process_batch, batches))
|
predictions = chain.from_iterable(map(self.__process_batch, batches))
|
||||||
|
return predictions
|
||||||
|
|||||||
@ -9,6 +9,9 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class ParsablePDFImageExtractor(ImageExtractor):
|
class ParsablePDFImageExtractor(ImageExtractor):
|
||||||
@ -47,8 +50,12 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
return starmap(ImageMetadataPair, zip(images, metadata))
|
return starmap(ImageMetadataPair, zip(images, metadata))
|
||||||
|
|
||||||
def extract(self, pdf: bytes):
|
def extract(self, pdf: bytes):
|
||||||
|
logger.debug("Extracting")
|
||||||
|
|
||||||
self.doc = fitz.Document(stream=pdf)
|
self.doc = fitz.Document(stream=pdf)
|
||||||
image_metadata_pairs = chain(
|
|
||||||
*map(self.__process_images_on_page, tqdm(self.doc, desc="Extracting", disable=not self.verbose))
|
image_metadata_pairs = chain.from_iterable(
|
||||||
|
map(self.__process_images_on_page, tqdm(self.doc, desc="Extracting", disable=not self.verbose))
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_metadata_pairs
|
return image_metadata_pairs
|
||||||
|
|||||||
@ -51,4 +51,4 @@ class Pipeline:
|
|||||||
self.pipe = rcompose(get_extractor_classifier(), get_formatter())
|
self.pipe = rcompose(get_extractor_classifier(), get_formatter())
|
||||||
|
|
||||||
def __call__(self, pdf: bytes):
|
def __call__(self, pdf: bytes):
|
||||||
return self.pipe(pdf)
|
yield from self.pipe(pdf)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user