alpha channel test fix
This commit is contained in:
parent
1d88876ab1
commit
62bfedfea8
@ -1,11 +1,12 @@
|
||||
import io
|
||||
from functools import partial, lru_cache
|
||||
from itertools import chain, starmap, filterfalse, repeat
|
||||
from operator import itemgetter, truth
|
||||
from itertools import chain, starmap, filterfalse
|
||||
from operator import itemgetter
|
||||
from typing import List
|
||||
|
||||
import fitz
|
||||
from PIL import Image
|
||||
from funcy import rcompose, compose, curry, merge, zipdict
|
||||
from funcy import rcompose, merge
|
||||
from tqdm import tqdm
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||
@ -43,12 +44,9 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
def __process_images_on_page(self, page: fitz.fitz.Page):
|
||||
images = get_images_on_page(self.doc, page)
|
||||
metadata = get_metadata_for_images_on_page(self.doc, page)
|
||||
get_image_infos.cache_clear()
|
||||
load_image_handle_from_xref.cache_clear()
|
||||
clear_caches()
|
||||
|
||||
image_metadata_pairs = starmap(
|
||||
ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata))
|
||||
)
|
||||
image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata)))
|
||||
image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance)
|
||||
|
||||
yield from image_metadata_pairs
|
||||
@ -61,6 +59,7 @@ def extract_pages(doc, page_range):
|
||||
return pages
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_images_on_page(doc, page: fitz.Page):
|
||||
image_infos = get_image_infos(page)
|
||||
xrefs = map(itemgetter("xref"), image_infos)
|
||||
@ -70,9 +69,8 @@ def get_images_on_page(doc, page: fitz.Page):
|
||||
|
||||
|
||||
def get_metadata_for_images_on_page(doc, page: fitz.Page):
|
||||
image_infos = get_image_infos(page)
|
||||
|
||||
metadata = map(get_image_metadata, image_infos)
|
||||
metadata = map(get_image_metadata, get_image_infos(page))
|
||||
metadata = validate_coords_and_passthrough(metadata)
|
||||
|
||||
metadata = filterfalse(tiny, metadata)
|
||||
@ -80,14 +78,20 @@ def get_metadata_for_images_on_page(doc, page: fitz.Page):
|
||||
|
||||
metadata = map(partial(merge, get_page_metadata(page)), metadata)
|
||||
|
||||
xrefs = map(itemgetter("xref"), image_infos)
|
||||
xrefs = map(itemgetter("xref"), get_image_infos(page))
|
||||
alpha = map(partial(has_alpha_channel, doc), xrefs)
|
||||
alpha = ({Info.ALPHA: a} for a in alpha)
|
||||
metadata = starmap(merge, zip(alpha, metadata))
|
||||
metadata = list(starmap(merge, zip(alpha, metadata)))
|
||||
|
||||
yield from metadata
|
||||
|
||||
|
||||
def clear_caches():
|
||||
get_image_infos.cache_clear()
|
||||
load_image_handle_from_xref.cache_clear()
|
||||
get_images_on_page.cache_clear()
|
||||
|
||||
|
||||
def validate_coords_and_passthrough(metadata):
|
||||
yield from map(validate_box_coords, metadata)
|
||||
|
||||
@ -96,23 +100,6 @@ def validate_size_and_passthrough(metadata):
|
||||
yield from map(validate_box_size, metadata)
|
||||
|
||||
|
||||
# def load_image_from_xref(doc, xref):
|
||||
#
|
||||
# maybe_image = doc.extract_image(xref)
|
||||
# if maybe_image:
|
||||
# smask = doc.extract_image(maybe_image["smask"])
|
||||
# pix1 = fitz.Pixmap(maybe_image) # (1) pixmap of image w/o alpha
|
||||
# mask = fitz.Pixmap(doc.extract_image(smask)["image"]) # (2) mask pixmap
|
||||
# pix = fitz.Pixmap(pix1, mask) # (3) copy of pix1, image mask added
|
||||
# im = Image.open(io.BytesIO(pix.tobytes()))
|
||||
# else:
|
||||
# im = None
|
||||
#
|
||||
# import IPython
|
||||
# IPython.embed()
|
||||
# return im
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def load_image_handle_from_xref(doc, xref):
|
||||
return doc.extract_image(xref)
|
||||
@ -133,15 +120,12 @@ def xref_to_image(doc, xref):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_image_infos(page: fitz.Page):
|
||||
def get_image_infos(page: fitz.Page) -> List[dict]:
|
||||
return page.get_image_info(xrefs=True)
|
||||
|
||||
|
||||
def get_image_metadata(image_info):
|
||||
|
||||
# import IPython
|
||||
# IPython.embed()
|
||||
# smask = doc.extract_image(maybe_image["smask"])
|
||||
x1, y1, x2, y2 = map(rounder, image_info["bbox"])
|
||||
|
||||
width = abs(x2 - x1)
|
||||
|
||||
@ -186,14 +186,15 @@ def batch_size(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def input_size(alpha, __input_size):
|
||||
w, h, d = __input_size
|
||||
return w, h, d + alpha
|
||||
|
||||
|
||||
@pytest.fixture(params=[False])
|
||||
def input_size(request, __input_size):
|
||||
alpha = request.param
|
||||
print(alpha)
|
||||
if alpha:
|
||||
w, h, d = __input_size
|
||||
__input_size = w, h, d + 1
|
||||
return __input_size
|
||||
def alpha(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
|
||||
@ -301,7 +302,7 @@ def metadata(images, info_label_map):
|
||||
info_label_map.X2: x2,
|
||||
info_label_map.Y1: y1,
|
||||
info_label_map.Y2: y2,
|
||||
info_label_map.ALPHA: image.mode == "RGBA"
|
||||
info_label_map.ALPHA: image.mode == "RGBA",
|
||||
}
|
||||
return metadata
|
||||
|
||||
|
||||
@ -17,14 +17,12 @@ def test_image_extractor_mock(image_extractor, images):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_size",
|
||||
[{"depth": 3, "width": 170, "height": 220}, {"depth": 3, "width": 170, "height": 220}],
|
||||
indirect=["input_size"],
|
||||
)
|
||||
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size):
|
||||
@pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"])
|
||||
@pytest.mark.parametrize("alpha", [False, True])
|
||||
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha):
|
||||
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
|
||||
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
|
||||
if not alpha:
|
||||
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
|
||||
assert list(metadata_extracted) == metadata
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user