alpha channel test fix

This commit is contained in:
Matthias Bisping 2022-04-13 12:06:55 +02:00
parent 1d88876ab1
commit 62bfedfea8
3 changed files with 31 additions and 48 deletions

View File

@ -1,11 +1,12 @@
import io
from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse, repeat
from operator import itemgetter, truth
from itertools import chain, starmap, filterfalse
from operator import itemgetter
from typing import List
import fitz
from PIL import Image
from funcy import rcompose, compose, curry, merge, zipdict
from funcy import rcompose, merge
from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
@ -43,12 +44,9 @@ class ParsablePDFImageExtractor(ImageExtractor):
def __process_images_on_page(self, page: fitz.fitz.Page):
images = get_images_on_page(self.doc, page)
metadata = get_metadata_for_images_on_page(self.doc, page)
get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
clear_caches()
image_metadata_pairs = starmap(
ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata))
)
image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata)))
image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance)
yield from image_metadata_pairs
@ -61,6 +59,7 @@ def extract_pages(doc, page_range):
return pages
@lru_cache(maxsize=None)
def get_images_on_page(doc, page: fitz.Page):
image_infos = get_image_infos(page)
xrefs = map(itemgetter("xref"), image_infos)
@ -70,9 +69,8 @@ def get_images_on_page(doc, page: fitz.Page):
def get_metadata_for_images_on_page(doc, page: fitz.Page):
image_infos = get_image_infos(page)
metadata = map(get_image_metadata, image_infos)
metadata = map(get_image_metadata, get_image_infos(page))
metadata = validate_coords_and_passthrough(metadata)
metadata = filterfalse(tiny, metadata)
@ -80,14 +78,20 @@ def get_metadata_for_images_on_page(doc, page: fitz.Page):
metadata = map(partial(merge, get_page_metadata(page)), metadata)
xrefs = map(itemgetter("xref"), image_infos)
xrefs = map(itemgetter("xref"), get_image_infos(page))
alpha = map(partial(has_alpha_channel, doc), xrefs)
alpha = ({Info.ALPHA: a} for a in alpha)
metadata = starmap(merge, zip(alpha, metadata))
metadata = list(starmap(merge, zip(alpha, metadata)))
yield from metadata
def clear_caches():
get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
get_images_on_page.cache_clear()
def validate_coords_and_passthrough(metadata):
yield from map(validate_box_coords, metadata)
@ -96,23 +100,6 @@ def validate_size_and_passthrough(metadata):
yield from map(validate_box_size, metadata)
# def load_image_from_xref(doc, xref):
#
# maybe_image = doc.extract_image(xref)
# if maybe_image:
# smask = doc.extract_image(maybe_image["smask"])
# pix1 = fitz.Pixmap(maybe_image) # (1) pixmap of image w/o alpha
# mask = fitz.Pixmap(doc.extract_image(smask)["image"]) # (2) mask pixmap
# pix = fitz.Pixmap(pix1, mask) # (3) copy of pix1, image mask added
# im = Image.open(io.BytesIO(pix.tobytes()))
# else:
# im = None
#
# import IPython
# IPython.embed()
# return im
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
return doc.extract_image(xref)
@ -133,15 +120,12 @@ def xref_to_image(doc, xref):
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page):
def get_image_infos(page: fitz.Page) -> List[dict]:
return page.get_image_info(xrefs=True)
def get_image_metadata(image_info):
# import IPython
# IPython.embed()
# smask = doc.extract_image(maybe_image["smask"])
x1, y1, x2, y2 = map(rounder, image_info["bbox"])
width = abs(x2 - x1)

View File

@ -186,14 +186,15 @@ def batch_size(request):
return request.param
@pytest.fixture
def input_size(alpha, __input_size):
w, h, d = __input_size
return w, h, d + alpha
@pytest.fixture(params=[False])
def input_size(request, __input_size):
alpha = request.param
print(alpha)
if alpha:
w, h, d = __input_size
__input_size = w, h, d + 1
return __input_size
def alpha(request):
return request.param
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
@ -301,7 +302,7 @@ def metadata(images, info_label_map):
info_label_map.X2: x2,
info_label_map.Y1: y1,
info_label_map.Y2: y2,
info_label_map.ALPHA: image.mode == "RGBA"
info_label_map.ALPHA: image.mode == "RGBA",
}
return metadata

View File

@ -17,14 +17,12 @@ def test_image_extractor_mock(image_extractor, images):
@pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"])
@pytest.mark.parametrize(
"input_size",
[{"depth": 3, "width": 170, "height": 220}, {"depth": 3, "width": 170, "height": 220}],
indirect=["input_size"],
)
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size):
@pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"])
@pytest.mark.parametrize("alpha", [False, True])
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha):
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
if not alpha:
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
assert list(metadata_extracted) == metadata