import atexit
import io
from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse
from operator import itemgetter
from typing import Iterator, List, Optional, Tuple

import fitz
from PIL import Image
from funcy import rcompose, merge, pluck, curry, compose, rpartial
from tqdm import tqdm

from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box_coords, validate_box_size
from image_prediction.utils.generic import lift


class ParsablePDFImageExtractor(ImageExtractor):
    """Extract embedded images, paired with their layout metadata, from a parsable PDF."""

    def __init__(self, verbose=False, tolerance=0, progress_message=None):
        """
        Args:
            verbose: Whether to show a progress bar. NOTE: the progress-bar code below is commented out,
                so this flag currently has no effect.
            tolerance: The tolerance in pixels for the distance between images beyond which they will not
                be stitched together.
            progress_message: Description for the progress bar (currently unused, see ``verbose``).
        """
        self.doc: fitz.Document = None
        self.verbose = verbose
        self.tolerance = tolerance
        self.progress_message = progress_message

    def extract(self, pdf: bytes, page_range: range = None):
        """Lazily yield ``ImageMetadataPair`` objects for the images in ``pdf``.

        Args:
            pdf: Raw bytes of the PDF document; opened into ``self.doc``.
            page_range: Optional range of pages handed to ``extract_pages``; when falsy every page of the
                document is processed.

        Yields:
            The (possibly stitched) image/metadata pairs, page by page.
        """
        self.doc = fitz.Document(stream=pdf)

        # NOTE(review): an empty range is falsy and therefore also selects every page - confirm intended.
        pages = extract_pages(self.doc, page_range) if page_range else self.doc
        # pages = self.__maybe_show_progress(pages, page_range)

        yield from chain.from_iterable(map(self.__process_images_on_page, pages))

    # Progress-bar support, currently disabled (which is why ``verbose`` has no effect):
    # def __maybe_show_progress(self, iterable, page_range):
    #     return self.__progressbar(page_range)(iterable) if self.verbose else iterable
    #
    # def __progressbar(self, page_range):
    #     return rpartial(tqdm, desc=self.progress_message, total=len(page_range) if page_range else None)

    def __process_images_on_page(self, page: fitz.Page):
        """Yield the stitched image/metadata pairs found on a single page."""
        try:
            # Materialize the whole page while the memoized helpers share their per-page results ...
            pairs = list(_iter_image_metadata_pairs_on_page(self.doc, page))
        finally:
            # ... then drop the caches so pages, documents and image data do not outlive this page.
            clear_caches()

        yield from stitch_pairs(pairs, tolerance=self.tolerance)


def extract_pages(doc, page_range):
    """Yield the pages of ``doc`` selected by ``page_range`` (its step is ignored).

    NOTE(review): the range is shifted by +1 before ``doc.load_page``, which PyMuPDF documents as 0-based.
    This looks off by one unless callers deliberately pass a range one below the wanted page numbers -
    confirm against the callers before changing.
    """
    shifted_range = range(page_range.start + 1, page_range.stop + 1)
    yield from map(doc.load_page, shifted_range)


def get_images_on_page(doc, page: fitz.Page):
    """Yield a PIL image (or ``None`` when it cannot be extracted) for every image info on ``page``.

    Deliberately not memoized: caching a generator would hand an already exhausted generator to every
    later caller with the same arguments.
    """
    xrefs = map(itemgetter("xref"), get_image_infos(page))
    yield from map(partial(xref_to_image, doc), xrefs)


def get_metadata_for_images_on_page(doc, page: fitz.Page):
    """Yield the metadata dict of every non-tiny image on ``page``.

    Each dict holds the validated bounding box, the page metadata and the alpha flag of the very image
    it describes (filtering happens per image, so the alpha flag cannot shift onto another image).
    """
    yield from map(itemgetter(1), _iter_image_infos_with_metadata(doc, page))


def _iter_image_infos_with_metadata(doc, page: fitz.Page) -> Iterator[Tuple[dict, dict]]:
    """Yield ``(image_info, metadata)`` for every non-tiny image on ``page``, keeping both aligned.

    Tiny images are skipped per image info instead of being filtered out of a separate metadata stream,
    so callers can pair the metadata with the matching image and alpha flag.
    """
    page_metadata = get_page_metadata(page)
    for image_info in get_image_infos(page):
        metadata = validate_box_coords(get_image_metadata(image_info))
        if tiny(metadata):
            continue
        metadata = validate_box_size(metadata)
        metadata = merge(page_metadata, metadata)
        metadata = merge({Info.ALPHA: has_alpha_channel(doc, image_info["xref"])}, metadata)
        yield image_info, metadata


def _iter_image_metadata_pairs_on_page(doc, page: fitz.Page) -> Iterator[ImageMetadataPair]:
    """Yield an ``ImageMetadataPair`` for every non-tiny image on ``page`` whose image could be loaded."""
    for image_info, metadata in _iter_image_infos_with_metadata(doc, page):
        image = xref_to_image(doc, image_info["xref"])
        # Same truthiness check as the former ``filter(all, zip(images, metadata))``.
        if image and metadata:
            yield ImageMetadataPair(image, metadata)


@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]:
    """Return PyMuPDF's image info list for ``page`` (including xrefs); memoized until ``clear_caches``."""
    return page.get_image_info(xrefs=True)


@lru_cache(maxsize=None)
def xref_to_image(doc, xref) -> Optional[Image.Image]:
    """Open the image stored at ``xref`` as a PIL image, or return ``None`` if nothing was extracted."""
    maybe_image = load_image_handle_from_xref(doc, xref)
    return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None


def get_image_metadata(image_info):
    """Build the bounding-box metadata (rounded coordinates, width, height) from an image info dict."""
    x1, y1, x2, y2 = map(rounder, image_info["bbox"])

    width = abs(x2 - x1)
    height = abs(y2 - y1)

    return {
        Info.WIDTH: width,
        Info.HEIGHT: height,
        Info.X1: x1,
        Info.X2: x2,
        Info.Y1: y1,
        Info.Y2: y2,
    }


def validate_coords_and_passthrough(metadata):
    """Yield each metadata item after passing it through ``validate_box_coords``."""
    yield from map(validate_box_coords, metadata)


def filter_out_tiny_images(metadata):
    """Yield only the metadata items that are not ``tiny``."""
    yield from filterfalse(tiny, metadata)


def validate_size_and_passthrough(metadata):
    """Yield each metadata item after passing it through ``validate_box_size``."""
    yield from map(validate_box_size, metadata)


def add_page_metadata(page, metadata):
    """Yield each metadata item merged on top of ``page``'s metadata."""
    yield from map(partial(merge, get_page_metadata(page)), metadata)


def add_alpha_channel_info(doc, page, metadata):
    """Merge an ``Info.ALPHA`` flag into each metadata item.

    Precondition: ``metadata`` must contain exactly one item per entry of ``get_image_infos(page)``, in
    the same order - the flags are zipped positionally, so a filtered stream shifts them onto the wrong
    images. (Kept for backward compatibility; the extractor itself no longer relies on it.)
    """
    page_to_xrefs = compose(curry(pluck)("xref"), get_image_infos)
    xref_to_alpha = partial(has_alpha_channel, doc)
    # ``lift`` is a project helper; presumably it maps the wrapped function over an iterable.
    page_to_alpha_value_per_image = compose(lift(xref_to_alpha), page_to_xrefs)
    alpha_to_dict = compose(dict, lambda a: [(Info.ALPHA, a)])
    page_to_alpha_mapping_per_image = compose(lift(alpha_to_dict), page_to_alpha_value_per_image)

    metadata = starmap(merge, zip(page_to_alpha_mapping_per_image(page), metadata))

    yield from metadata


@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
    """Return ``doc.extract_image(xref)``; memoized until ``clear_caches``."""
    return doc.extract_image(xref)


# Round a float and make sure the result is an int.
rounder = rcompose(round, int)


def get_page_metadata(page):
    """Return the rounded mediabox width/height and the page number of ``page``."""
    page_width, page_height = map(rounder, page.mediabox_size)

    return {
        Info.PAGE_WIDTH: page_width,
        Info.PAGE_HEIGHT: page_height,
        Info.PAGE_IDX: page.number,
    }


def has_alpha_channel(doc, xref):
    """Return whether the image stored at ``xref`` has transparency.

    If the image has a soft mask (``smask``), it counts as transparent when the mask can be extracted or
    the mask's pixmap has an alpha channel; otherwise the image's own pixmap is inspected.
    """
    maybe_image = load_image_handle_from_xref(doc, xref)
    maybe_smask = maybe_image["smask"] if maybe_image else None

    if maybe_smask:
        # Short-circuit: only decode the mask's pixmap when the cheaper extraction check is inconclusive.
        return doc.extract_image(maybe_smask) is not None or bool(fitz.Pixmap(doc, maybe_smask).alpha)

    return bool(fitz.Pixmap(doc, xref).alpha)


def tiny(metadata):
    """Return True for images whose rounded bounding-box area is at most 4."""
    return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4


def clear_caches():
    """Empty the memoization caches of the per-page helpers (they hold pages, documents and images)."""
    get_image_infos.cache_clear()
    load_image_handle_from_xref.cache_clear()
    xref_to_image.cache_clear()


atexit.register(clear_caches)