alpha channel test fix

This commit is contained in:
Matthias Bisping 2022-04-13 12:06:55 +02:00
parent 1d88876ab1
commit 62bfedfea8
3 changed files with 31 additions and 48 deletions

View File

@ -1,11 +1,12 @@
import io import io
from functools import partial, lru_cache from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse, repeat from itertools import chain, starmap, filterfalse
from operator import itemgetter, truth from operator import itemgetter
from typing import List
import fitz import fitz
from PIL import Image from PIL import Image
from funcy import rcompose, compose, curry, merge, zipdict from funcy import rcompose, merge
from tqdm import tqdm from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
@ -43,12 +44,9 @@ class ParsablePDFImageExtractor(ImageExtractor):
def __process_images_on_page(self, page: fitz.fitz.Page): def __process_images_on_page(self, page: fitz.fitz.Page):
images = get_images_on_page(self.doc, page) images = get_images_on_page(self.doc, page)
metadata = get_metadata_for_images_on_page(self.doc, page) metadata = get_metadata_for_images_on_page(self.doc, page)
get_image_infos.cache_clear() clear_caches()
load_image_handle_from_xref.cache_clear()
image_metadata_pairs = starmap( image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata)))
ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata))
)
image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance) image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance)
yield from image_metadata_pairs yield from image_metadata_pairs
@ -61,6 +59,7 @@ def extract_pages(doc, page_range):
return pages return pages
@lru_cache(maxsize=None)
def get_images_on_page(doc, page: fitz.Page): def get_images_on_page(doc, page: fitz.Page):
image_infos = get_image_infos(page) image_infos = get_image_infos(page)
xrefs = map(itemgetter("xref"), image_infos) xrefs = map(itemgetter("xref"), image_infos)
@ -70,9 +69,8 @@ def get_images_on_page(doc, page: fitz.Page):
def get_metadata_for_images_on_page(doc, page: fitz.Page): def get_metadata_for_images_on_page(doc, page: fitz.Page):
image_infos = get_image_infos(page)
metadata = map(get_image_metadata, image_infos) metadata = map(get_image_metadata, get_image_infos(page))
metadata = validate_coords_and_passthrough(metadata) metadata = validate_coords_and_passthrough(metadata)
metadata = filterfalse(tiny, metadata) metadata = filterfalse(tiny, metadata)
@ -80,14 +78,20 @@ def get_metadata_for_images_on_page(doc, page: fitz.Page):
metadata = map(partial(merge, get_page_metadata(page)), metadata) metadata = map(partial(merge, get_page_metadata(page)), metadata)
xrefs = map(itemgetter("xref"), image_infos) xrefs = map(itemgetter("xref"), get_image_infos(page))
alpha = map(partial(has_alpha_channel, doc), xrefs) alpha = map(partial(has_alpha_channel, doc), xrefs)
alpha = ({Info.ALPHA: a} for a in alpha) alpha = ({Info.ALPHA: a} for a in alpha)
metadata = starmap(merge, zip(alpha, metadata)) metadata = list(starmap(merge, zip(alpha, metadata)))
yield from metadata yield from metadata
def clear_caches():
get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
get_images_on_page.cache_clear()
def validate_coords_and_passthrough(metadata): def validate_coords_and_passthrough(metadata):
yield from map(validate_box_coords, metadata) yield from map(validate_box_coords, metadata)
@ -96,23 +100,6 @@ def validate_size_and_passthrough(metadata):
yield from map(validate_box_size, metadata) yield from map(validate_box_size, metadata)
# def load_image_from_xref(doc, xref):
#
# maybe_image = doc.extract_image(xref)
# if maybe_image:
# smask = doc.extract_image(maybe_image["smask"])
# pix1 = fitz.Pixmap(maybe_image) # (1) pixmap of image w/o alpha
# mask = fitz.Pixmap(doc.extract_image(smask)["image"]) # (2) mask pixmap
# pix = fitz.Pixmap(pix1, mask) # (3) copy of pix1, image mask added
# im = Image.open(io.BytesIO(pix.tobytes()))
# else:
# im = None
#
# import IPython
# IPython.embed()
# return im
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref): def load_image_handle_from_xref(doc, xref):
return doc.extract_image(xref) return doc.extract_image(xref)
@ -133,15 +120,12 @@ def xref_to_image(doc, xref):
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page): def get_image_infos(page: fitz.Page) -> List[dict]:
return page.get_image_info(xrefs=True) return page.get_image_info(xrefs=True)
def get_image_metadata(image_info): def get_image_metadata(image_info):
# import IPython
# IPython.embed()
# smask = doc.extract_image(maybe_image["smask"])
x1, y1, x2, y2 = map(rounder, image_info["bbox"]) x1, y1, x2, y2 = map(rounder, image_info["bbox"])
width = abs(x2 - x1) width = abs(x2 - x1)

View File

@ -186,14 +186,15 @@ def batch_size(request):
return request.param return request.param
@pytest.fixture
def input_size(alpha, __input_size):
w, h, d = __input_size
return w, h, d + alpha
@pytest.fixture(params=[False]) @pytest.fixture(params=[False])
def input_size(request, __input_size): def alpha(request):
alpha = request.param return request.param
print(alpha)
if alpha:
w, h, d = __input_size
__input_size = w, h, d + 1
return __input_size
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}]) @pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
@ -301,7 +302,7 @@ def metadata(images, info_label_map):
info_label_map.X2: x2, info_label_map.X2: x2,
info_label_map.Y1: y1, info_label_map.Y1: y1,
info_label_map.Y2: y2, info_label_map.Y2: y2,
info_label_map.ALPHA: image.mode == "RGBA" info_label_map.ALPHA: image.mode == "RGBA",
} }
return metadata return metadata

View File

@ -17,14 +17,12 @@ def test_image_extractor_mock(image_extractor, images):
@pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"]) @pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"])
@pytest.mark.parametrize( @pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"])
"input_size", @pytest.mark.parametrize("alpha", [False, True])
[{"depth": 3, "width": 170, "height": 220}, {"depth": 3, "width": 170, "height": 220}], def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha):
indirect=["input_size"],
)
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size):
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor)) images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images)) if not alpha:
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
assert list(metadata_extracted) == metadata assert list(metadata_extracted) == metadata