refactoring in preparationfor alpha channel info

This commit is contained in:
Matthias Bisping 2022-04-12 18:22:38 +02:00
parent f17a232009
commit bbafad5561
5 changed files with 61 additions and 11 deletions

View File

@ -1,11 +1,11 @@
import io import io
from functools import partial, lru_cache from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse from itertools import chain, starmap, filterfalse, repeat
from operator import itemgetter, truth from operator import itemgetter, truth
import fitz import fitz
from PIL import Image from PIL import Image
from funcy import rcompose, compose, curry, merge from funcy import rcompose, compose, curry, merge, zipdict
from tqdm import tqdm from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
@ -44,6 +44,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
images = get_images_on_page(self.doc, page) images = get_images_on_page(self.doc, page)
metadata = get_metadata_for_images_on_page(page) metadata = get_metadata_for_images_on_page(page)
get_image_infos.cache_clear() get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
image_metadata_pairs = starmap( image_metadata_pairs = starmap(
ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata)) ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata))
@ -63,7 +64,7 @@ def extract_pages(doc, page_range):
def get_images_on_page(doc, page: fitz.Page): def get_images_on_page(doc, page: fitz.Page):
image_infos = get_image_infos(page) image_infos = get_image_infos(page)
xrefs = map(itemgetter("xref"), image_infos) xrefs = map(itemgetter("xref"), image_infos)
images = map(partial(load_image_from_xref, doc), xrefs) images = map(partial(xref_to_image, doc), xrefs)
return images return images
@ -76,6 +77,11 @@ def get_metadata_for_images_on_page(page: fitz.Page):
metadata = validate_size_and_passthrough(metadata) metadata = validate_size_and_passthrough(metadata)
metadata = map(partial(merge, get_page_metadata(page)), metadata) metadata = map(partial(merge, get_page_metadata(page)), metadata)
# xrefs = map(itemgetter("xref"), image_infos)
# alpha = map(has_alpha_channel, xrefs)
# alpha = zipdict(repeat(Info.ALPHA), alpha)
# metadata = starmap(merge, zip(alpha, metadata))
yield from metadata yield from metadata
@ -87,8 +93,39 @@ def validate_size_and_passthrough(metadata):
yield from map(validate_box_size, metadata) yield from map(validate_box_size, metadata)
def load_image_from_xref(doc, xref): # def load_image_from_xref(doc, xref):
maybe_image = doc.extract_image(xref) #
# maybe_image = doc.extract_image(xref)
# if maybe_image:
# smask = doc.extract_image(maybe_image["smask"])
# pix1 = fitz.Pixmap(maybe_image) # (1) pixmap of image w/o alpha
# mask = fitz.Pixmap(doc.extract_image(smask)["image"]) # (2) mask pixmap
# pix = fitz.Pixmap(pix1, mask) # (3) copy of pix1, image mask added
# im = Image.open(io.BytesIO(pix.tobytes()))
# else:
# im = None
#
# import IPython
# IPython.embed()
# return im
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
return doc.extract_image(xref)
def has_alpha_channel(doc, xref):
maybe_image = load_image_handle_from_xref(doc, xref)
if maybe_image:
maybe_smask = doc.extract_image(maybe_image["smask"])
return maybe_smask is not None
else:
return False
def xref_to_image(doc, xref):
maybe_image = load_image_handle_from_xref(doc, xref)
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
@ -98,6 +135,10 @@ def get_image_infos(page: fitz.Page):
def get_image_metadata(image_info): def get_image_metadata(image_info):
# import IPython
# IPython.embed()
# smask = doc.extract_image(maybe_image["smask"])
x1, y1, x2, y2 = map(rounder, image_info["bbox"]) x1, y1, x2, y2 = map(rounder, image_info["bbox"])
width = abs(x2 - x1) width = abs(x2 - x1)

View File

@ -11,3 +11,4 @@ class Info(Enum):
X2 = "x2" X2 = "x2"
Y1 = "y1" Y1 = "y1"
Y2 = "y2" Y2 = "y2"
# ALPHA = "alpha"

View File

@ -35,7 +35,7 @@ def process_pdf(pipeline, pdf_path, page_range=None):
def main(args): def main(args):
pipeline = load_pipeline(verbose=True, tolerance=3) pipeline = load_pipeline(verbose=False, tolerance=3)
if os.path.isfile(args.input): if os.path.isfile(args.input):
pdf_paths = [args.input] pdf_paths = [args.input]

View File

@ -194,12 +194,16 @@ def input_size(request):
def array_to_image(array): def array_to_image(array):
assert np.all(array <= 1) assert np.all(array <= 1)
assert np.all(array >= 0) assert np.all(array >= 0)
return Image.fromarray(np.uint8(array * 255), mode="RGB")
if array.shape[-1] == 3:
mode = "RGB"
elif array.shape[-1] == 4:
mode = "RGBA"
else:
raise ValueError(f"Unexpected number of channels {array.shape[-1]}. Expected 3 or 4.")
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}]) # noinspection PyTypeChecker
def input_size(request): return Image.fromarray(np.uint8(array * 255), mode=mode)
return itemgetter("width", "height", "depth")(request.param)
@pytest.fixture @pytest.fixture

View File

@ -17,7 +17,11 @@ def test_image_extractor_mock(image_extractor, images):
@pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"]) @pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"])
@pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"]) @pytest.mark.parametrize(
"input_size",
[{"depth": 3, "width": 170, "height": 220}, {"depth": 3, "width": 170, "height": 220}],
indirect=["input_size"],
)
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size): def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size):
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor)) images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images)) assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))