diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index d960eeb..b70b0d9 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,15 +1,40 @@ import io +from functools import partial from itertools import chain, starmap from operator import itemgetter, truth import fitz from PIL import Image -from funcy import rcompose, compose, curry +from funcy import rcompose, compose, curry, merge from tqdm import tqdm from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.info import Info +rounder = rcompose(round, int) + + +def get_image_metadata(image_info): + x1, y1, x2, y2 = map(rounder, image_info["bbox"]) + width, height = itemgetter("width", "height")(image_info) + return { + Info.WIDTH: width, + Info.HEIGHT: height, + Info.X1: x1, + Info.X2: x2, + Info.Y1: y1, + Info.Y2: y2, + } + + +def get_page_metadata(page): + page_width, page_height = map(rounder, page.mediabox_size) + + return { + Info.PAGE_WIDTH: page_width, + Info.PAGE_HEIGHT: page_height, + Info.PAGE_IDX: page.number, + } class ParsablePDFImageExtractor(ImageExtractor): def __init__(self, verbose=False): @@ -24,37 +49,29 @@ class ParsablePDFImageExtractor(ImageExtractor): else: return None - def format_metadata(image_info): - x1, y1, x2, y2 = map(rounder, image_info["bbox"]) - width, height = itemgetter("width", "height")(image_info) - return { - Info.PAGE_WIDTH: page_width, - Info.PAGE_HEIGHT: page_height, - Info.PAGE_IDX: page.number, - Info.WIDTH: width, - Info.HEIGHT: height, - Info.X1: x1, - Info.X2: x2, - Info.Y1: y1, - Info.Y2: y2, - } - - rounder = rcompose(round, int) - - page_width, page_height = map(rounder, page.mediabox_size) - image_infos = page.get_image_info(xrefs=True) xrefs = map(itemgetter("xref"), image_infos) images = map(load_image_from_xref, xrefs) - metadata = map(format_metadata, image_infos) + metadata = map(get_image_metadata, image_infos) + metadata = map(partial(merge, get_page_metadata(page)), metadata) return starmap(ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata))) - def extract(self, pdf: bytes): + def extract(self, pdf: bytes, page_range: range = None): self.doc = fitz.Document(stream=pdf) + if page_range: + page_range = range(page_range.start + 1, page_range.stop + 1) + doc = fitz.Document(stream=pdf) + pages = map(doc.load_page, page_range) + else: + pages = self.doc + image_metadata_pairs = chain.from_iterable( - map(self.__process_images_on_page, tqdm(self.doc, desc="Extracting", disable=not self.verbose)) + map( + self.__process_images_on_page, + tqdm(pages, desc="Extracting", disable=not self.verbose, total=len(page_range) if page_range else None) + ) ) return image_metadata_pairs