refactoring

This commit is contained in:
Matthias Bisping 2022-04-05 13:03:22 +02:00
parent e0885c545a
commit 4756b8c9bd

View File

@ -1,15 +1,40 @@
import io import io
from functools import partial
from itertools import chain, starmap from itertools import chain, starmap
from operator import itemgetter, truth from operator import itemgetter, truth
import fitz import fitz
from PIL import Image from PIL import Image
from funcy import rcompose, compose, curry from funcy import rcompose, compose, curry, merge
from tqdm import tqdm from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info from image_prediction.info import Info
# Shared coordinate normaliser: applies `round` first and then `int`
# (rcompose chains its functions left to right), so values end up as plain ints.
rounder = rcompose(round, int)
def get_image_metadata(image_info):
    """Build the image-level metadata dict for one image-info record.

    Args:
        image_info: Mapping providing a four-element ``"bbox"`` entry as well
            as ``"width"`` and ``"height"`` entries.

    Returns:
        Dict keyed by ``Info`` members holding the image's width and height
        (taken over unchanged) and its bounding-box corners, each passed
        through ``rounder``.
    """
    # The bbox is unpacked as (x1, y1, x2, y2).
    left, top, right, bottom = (rounder(coordinate) for coordinate in image_info["bbox"])
    image_width = image_info["width"]
    image_height = image_info["height"]

    metadata = {
        Info.WIDTH: image_width,
        Info.HEIGHT: image_height,
        Info.X1: left,
        Info.X2: right,
        Info.Y1: top,
        Info.Y2: bottom,
    }
    return metadata
def get_page_metadata(page):
    """Build the page-level metadata dict for a page.

    Args:
        page: Page object exposing ``mediabox_size`` (a two-element iterable)
            and ``number``.

    Returns:
        Dict keyed by ``Info`` members holding the page's width and height
        (each passed through ``rounder``) and the page's ``number``.
    """
    rounded_width, rounded_height = (rounder(extent) for extent in page.mediabox_size)

    metadata = {
        Info.PAGE_WIDTH: rounded_width,
        Info.PAGE_HEIGHT: rounded_height,
        Info.PAGE_IDX: page.number,
    }
    return metadata
class ParsablePDFImageExtractor(ImageExtractor): class ParsablePDFImageExtractor(ImageExtractor):
def __init__(self, verbose=False): def __init__(self, verbose=False):
@ -24,37 +49,29 @@ class ParsablePDFImageExtractor(ImageExtractor):
else: else:
return None return None
def format_metadata(image_info):
    """Build the combined page and image metadata dict for one image-info record.

    Reads ``page_width``, ``page_height`` and ``page`` from the enclosing
    scope rather than taking them as arguments.

    Args:
        image_info: Mapping providing a four-element ``"bbox"`` entry as well
            as ``"width"`` and ``"height"`` entries.

    Returns:
        Dict keyed by ``Info`` members with the page dimensions and index,
        the image's width and height, and its bounding-box corners, each
        passed through ``rounder``.
    """
    # The bbox is unpacked as (x1, y1, x2, y2).
    left, top, right, bottom = (rounder(coordinate) for coordinate in image_info["bbox"])
    image_width = image_info["width"]
    image_height = image_info["height"]

    metadata = {
        Info.PAGE_WIDTH: page_width,
        Info.PAGE_HEIGHT: page_height,
        Info.PAGE_IDX: page.number,
        Info.WIDTH: image_width,
        Info.HEIGHT: image_height,
        Info.X1: left,
        Info.X2: right,
        Info.Y1: top,
        Info.Y2: bottom,
    }
    return metadata
rounder = rcompose(round, int)
page_width, page_height = map(rounder, page.mediabox_size)
image_infos = page.get_image_info(xrefs=True) image_infos = page.get_image_info(xrefs=True)
xrefs = map(itemgetter("xref"), image_infos) xrefs = map(itemgetter("xref"), image_infos)
images = map(load_image_from_xref, xrefs) images = map(load_image_from_xref, xrefs)
metadata = map(format_metadata, image_infos) metadata = map(get_image_metadata, image_infos)
metadata = map(partial(merge, get_page_metadata(page)), metadata)
return starmap(ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata))) return starmap(ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata)))
def extract(self, pdf: bytes, page_range: range = None):
    """Lazily extract the images of a PDF together with their metadata.

    The document is opened from ``pdf`` and stored on ``self.doc``. Each
    selected page is handed to ``__process_images_on_page`` and the per-page
    results are chained into one iterator.

    Args:
        pdf: Raw bytes of the PDF document.
        page_range: Optional range selecting the pages to process. ``None``
            (the default) processes every page of the document. An empty
            range selects no pages.

    Returns:
        A lazy iterator over the chained per-page results; no page is
        processed until the iterator is consumed.
    """
    self.doc = fitz.Document(stream=pdf)

    # Test against None rather than truthiness: an empty range is falsy and
    # would otherwise silently fall back to processing the whole document.
    if page_range is not None:
        # NOTE(review): the requested range is shifted by one before loading
        # the pages - confirm this matches the indexing convention callers use.
        page_range = range(page_range.start + 1, page_range.stop + 1)
        # Load from the document that was just opened instead of parsing the
        # same bytes a second time.
        pages = map(self.doc.load_page, page_range)
        total = len(page_range)
    else:
        pages = self.doc
        # Without an explicit total, tqdm determines the length on its own.
        total = None

    progress = tqdm(pages, desc="Extracting", disable=not self.verbose, total=total)
    image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, progress))

    return image_metadata_pairs