2022-04-12 18:44:04 +02:00

175 lines
5.0 KiB
Python

import io
from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse, repeat
from operator import itemgetter, truth
import fitz
from PIL import Image
from funcy import rcompose, compose, curry, merge, zipdict
from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box_coords, validate_box_size
class ParsablePDFImageExtractor(ImageExtractor):
    """Extract images and their placement metadata from a PDF using PyMuPDF (``fitz``)."""

    def __init__(self, verbose=False, tolerance=0):
        """
        Args:
            verbose: Whether to show a progress bar while iterating over pages.
            tolerance: Distance in pixels between images beyond which they will not be
                stitched together.
        """
        self.doc: fitz.fitz.Document = None
        self.verbose = verbose
        self.tolerance = tolerance

    def extract(self, pdf: bytes, page_range: range = None):
        """Yield the stitched image/metadata pairs of every selected page of ``pdf``.

        Args:
            pdf: Raw bytes of the PDF document.
            page_range: Optional range selecting the pages to process (passed to
                ``extract_pages``). ``None`` processes every page; an empty range
                processes no pages.

        Yields:
            The items returned by ``stitch_pairs`` for each processed page.
        """
        self.doc = fitz.Document(stream=pdf)
        # Compare against None explicitly: an empty range is falsy, and treating it as
        # "no range given" would silently process the whole document instead of nothing.
        if page_range is not None:
            pages = extract_pages(self.doc, page_range)
            total = len(page_range)
        else:
            pages = self.doc
            total = None
        progress = tqdm(pages, desc="Extracting", disable=not self.verbose, total=total)
        yield from chain.from_iterable(map(self.__process_images_on_page, progress))

    def __process_images_on_page(self, page: fitz.fitz.Page):
        """Yield the stitched image/metadata pairs found on a single page.

        Images and metadata are paired positionally; pairs in which either side is
        missing (falsy) are dropped before stitching.
        """
        try:
            images = get_images_on_page(self.doc, page)
            metadata = get_metadata_for_images_on_page(self.doc, page)
            # Both inputs are lazy, so the pairs must be fully built before the per-page
            # caches they rely on are cleared.
            pairs = [
                ImageMetadataPair(image, image_metadata)
                for image, image_metadata in zip(images, metadata)
                if image and image_metadata
            ]
        finally:
            # The caches are keyed on the document/page objects; emptying them once the
            # page is done avoids recomputation within the page and stops them from
            # holding references to the page after extraction finishes.
            get_image_infos.cache_clear()
            load_image_handle_from_xref.cache_clear()
        yield from stitch_pairs(pairs, tolerance=self.tolerance)
def extract_pages(doc, page_range):
    """Lazily load the pages of ``doc`` selected by ``page_range``.

    Each index from ``page_range.start`` up to (excluding) ``page_range.stop`` is shifted
    up by one before being passed to ``doc.load_page``; the range's step is ignored.
    """
    # NOTE(review): PyMuPDF documents ``load_page`` as taking a 0-based page number, so
    # for a 0-based ``page_range`` this shift skips the first requested page and asks for
    # one past the last. Confirm the convention callers use before changing it.
    shifted = range(page_range.start + 1, page_range.stop + 1)
    load = doc.load_page
    return (load(number) for number in shifted)
def get_images_on_page(doc, page: fitz.Page):
    """Lazily yield the image for each non-tiny image placement on ``page``.

    Entries whose xref cannot be extracted yield ``None`` (see ``xref_to_image``).

    Placements whose bounding box is ``tiny`` are skipped. ``get_metadata_for_images_on_page``
    drops the same entries, and the two sequences are zipped positionally by the extractor.
    Without this filter every image after the first tiny one would be paired with the
    metadata of a different image.
    """
    image_infos = get_image_infos(page)
    # NOTE(review): the metadata side applies ``tiny`` after ``validate_box_coords``; this
    # assumes that validator passes the box through unchanged — confirm.
    return (
        xref_to_image(doc, image_info["xref"])
        for image_info in image_infos
        if not tiny(get_image_metadata(image_info))
    )
def get_metadata_for_images_on_page(doc, page: fitz.Page):
    """Yield one metadata dict per non-tiny image placement on ``page``.

    Each dict merges, in this order, the alpha flag (``Info.ALPHA``), the page metadata
    from ``get_page_metadata`` and the validated box geometry from ``get_image_metadata``.
    Placements whose box is ``tiny`` are skipped.
    """
    image_infos = get_image_infos(page)
    page_metadata = get_page_metadata(page)
    for image_info in image_infos:
        metadata = validate_box_coords(get_image_metadata(image_info))
        if tiny(metadata):
            continue
        metadata = validate_box_size(metadata)
        # Take the alpha flag from this entry's own xref so it stays attached to the right
        # image even when tiny entries are skipped.
        alpha = has_alpha_channel(doc, image_info["xref"])
        yield merge({Info.ALPHA: alpha}, page_metadata, metadata)
def validate_coords_and_passthrough(metadata):
    """Lazily yield ``validate_box_coords(entry)`` for each entry of ``metadata``."""
    for entry in metadata:
        yield validate_box_coords(entry)
def validate_size_and_passthrough(metadata):
    """Lazily yield ``validate_box_size(entry)`` for each entry of ``metadata``."""
    for entry in metadata:
        yield validate_box_size(entry)
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
    """Return ``doc.extract_image(xref)``, memoised per ``(doc, xref)``.

    The cache is unbounded; the extractor empties it with
    ``load_image_handle_from_xref.cache_clear()`` while processing pages.

    Returns ``None`` without touching the document when ``xref`` is 0 (or otherwise
    falsy). PyMuPDF reports inline images with xref 0 in ``Page.get_image_info``, and 0 is
    not a valid xref for ``Document.extract_image``.
    """
    if not xref:
        return None
    return doc.extract_image(xref)
def has_alpha_channel(doc, xref):
    """Return whether the image object ``xref`` has an extractable soft mask (``/SMask``).

    Returns False when ``xref`` is 0, the image cannot be extracted, or it has no soft mask.
    """
    if not xref:
        # Inline images (xref 0) have no image object to inspect.
        return False
    image_handle = load_image_handle_from_xref(doc, xref)
    if not image_handle:
        return False
    # ``extract_image`` reports the soft-mask xref under "smask", with 0 meaning "none".
    # 0 is not a valid xref, so it must not be passed back into ``extract_image``.
    smask_xref = image_handle.get("smask", 0)
    if not smask_xref:
        return False
    return bool(doc.extract_image(smask_xref))
def xref_to_image(doc, xref):
    """Open the image stored under ``xref`` with PIL, or return None if nothing was extracted."""
    handle = load_image_handle_from_xref(doc, xref)
    if not handle:
        return None
    raw_bytes = handle["image"]
    return Image.open(io.BytesIO(raw_bytes))
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page):
    """Return ``page.get_image_info(xrefs=True)``, memoised per page object.

    The cache is unbounded; the extractor empties it with ``get_image_infos.cache_clear()``.
    """
    infos = page.get_image_info(xrefs=True)
    return infos
def get_image_metadata(image_info):
    """Build the box geometry for one entry returned by ``page.get_image_info``.

    The four ``"bbox"`` coordinates are rounded to ints with ``rounder``; width and height
    are the absolute differences of the rounded coordinates.
    """
    x_start, y_start, x_end, y_end = (rounder(coord) for coord in image_info["bbox"])
    return {
        Info.WIDTH: abs(x_end - x_start),
        Info.HEIGHT: abs(y_end - y_start),
        Info.X1: x_start,
        Info.X2: x_end,
        Info.Y1: y_start,
        Info.Y2: y_end,
    }
def tiny(metadata):
    """Return True when the box area (``Info.WIDTH`` * ``Info.HEIGHT``) is at most 4."""
    area = metadata[Info.WIDTH] * metadata[Info.HEIGHT]
    return area <= 4
def get_page_metadata(page):
    """Return the rounded mediabox width/height and the ``number`` of ``page``."""
    width, height = (rounder(side) for side in page.mediabox_size)
    page_metadata = {
        Info.PAGE_WIDTH: width,
        Info.PAGE_HEIGHT: height,
        Info.PAGE_IDX: page.number,
    }
    return page_metadata
def rounder(number, ndigits=None):
    """Round ``number`` with the built-in ``round`` and convert the result to ``int``."""
    return int(round(number, ndigits))