refactoring in preparationfor alpha channel info
This commit is contained in:
parent
f17a232009
commit
bbafad5561
@ -1,11 +1,11 @@
|
|||||||
import io
|
import io
|
||||||
from functools import partial, lru_cache
|
from functools import partial, lru_cache
|
||||||
from itertools import chain, starmap, filterfalse
|
from itertools import chain, starmap, filterfalse, repeat
|
||||||
from operator import itemgetter, truth
|
from operator import itemgetter, truth
|
||||||
|
|
||||||
import fitz
|
import fitz
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import rcompose, compose, curry, merge
|
from funcy import rcompose, compose, curry, merge, zipdict
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||||
@ -44,6 +44,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
images = get_images_on_page(self.doc, page)
|
images = get_images_on_page(self.doc, page)
|
||||||
metadata = get_metadata_for_images_on_page(page)
|
metadata = get_metadata_for_images_on_page(page)
|
||||||
get_image_infos.cache_clear()
|
get_image_infos.cache_clear()
|
||||||
|
load_image_handle_from_xref.cache_clear()
|
||||||
|
|
||||||
image_metadata_pairs = starmap(
|
image_metadata_pairs = starmap(
|
||||||
ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata))
|
ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata))
|
||||||
@ -63,7 +64,7 @@ def extract_pages(doc, page_range):
|
|||||||
def get_images_on_page(doc, page: fitz.Page):
|
def get_images_on_page(doc, page: fitz.Page):
|
||||||
image_infos = get_image_infos(page)
|
image_infos = get_image_infos(page)
|
||||||
xrefs = map(itemgetter("xref"), image_infos)
|
xrefs = map(itemgetter("xref"), image_infos)
|
||||||
images = map(partial(load_image_from_xref, doc), xrefs)
|
images = map(partial(xref_to_image, doc), xrefs)
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
@ -76,6 +77,11 @@ def get_metadata_for_images_on_page(page: fitz.Page):
|
|||||||
metadata = validate_size_and_passthrough(metadata)
|
metadata = validate_size_and_passthrough(metadata)
|
||||||
metadata = map(partial(merge, get_page_metadata(page)), metadata)
|
metadata = map(partial(merge, get_page_metadata(page)), metadata)
|
||||||
|
|
||||||
|
# xrefs = map(itemgetter("xref"), image_infos)
|
||||||
|
# alpha = map(has_alpha_channel, xrefs)
|
||||||
|
# alpha = zipdict(repeat(Info.ALPHA), alpha)
|
||||||
|
# metadata = starmap(merge, zip(alpha, metadata))
|
||||||
|
|
||||||
yield from metadata
|
yield from metadata
|
||||||
|
|
||||||
|
|
||||||
@ -87,8 +93,39 @@ def validate_size_and_passthrough(metadata):
|
|||||||
yield from map(validate_box_size, metadata)
|
yield from map(validate_box_size, metadata)
|
||||||
|
|
||||||
|
|
||||||
def load_image_from_xref(doc, xref):
|
# def load_image_from_xref(doc, xref):
|
||||||
maybe_image = doc.extract_image(xref)
|
#
|
||||||
|
# maybe_image = doc.extract_image(xref)
|
||||||
|
# if maybe_image:
|
||||||
|
# smask = doc.extract_image(maybe_image["smask"])
|
||||||
|
# pix1 = fitz.Pixmap(maybe_image) # (1) pixmap of image w/o alpha
|
||||||
|
# mask = fitz.Pixmap(doc.extract_image(smask)["image"]) # (2) mask pixmap
|
||||||
|
# pix = fitz.Pixmap(pix1, mask) # (3) copy of pix1, image mask added
|
||||||
|
# im = Image.open(io.BytesIO(pix.tobytes()))
|
||||||
|
# else:
|
||||||
|
# im = None
|
||||||
|
#
|
||||||
|
# import IPython
|
||||||
|
# IPython.embed()
|
||||||
|
# return im
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def load_image_handle_from_xref(doc, xref):
|
||||||
|
return doc.extract_image(xref)
|
||||||
|
|
||||||
|
|
||||||
|
def has_alpha_channel(doc, xref):
|
||||||
|
maybe_image = load_image_handle_from_xref(doc, xref)
|
||||||
|
if maybe_image:
|
||||||
|
maybe_smask = doc.extract_image(maybe_image["smask"])
|
||||||
|
return maybe_smask is not None
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def xref_to_image(doc, xref):
|
||||||
|
maybe_image = load_image_handle_from_xref(doc, xref)
|
||||||
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
|
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
|
||||||
|
|
||||||
|
|
||||||
@ -98,6 +135,10 @@ def get_image_infos(page: fitz.Page):
|
|||||||
|
|
||||||
|
|
||||||
def get_image_metadata(image_info):
|
def get_image_metadata(image_info):
|
||||||
|
|
||||||
|
# import IPython
|
||||||
|
# IPython.embed()
|
||||||
|
# smask = doc.extract_image(maybe_image["smask"])
|
||||||
x1, y1, x2, y2 = map(rounder, image_info["bbox"])
|
x1, y1, x2, y2 = map(rounder, image_info["bbox"])
|
||||||
|
|
||||||
width = abs(x2 - x1)
|
width = abs(x2 - x1)
|
||||||
|
|||||||
@ -11,3 +11,4 @@ class Info(Enum):
|
|||||||
X2 = "x2"
|
X2 = "x2"
|
||||||
Y1 = "y1"
|
Y1 = "y1"
|
||||||
Y2 = "y2"
|
Y2 = "y2"
|
||||||
|
# ALPHA = "alpha"
|
||||||
|
|||||||
@ -35,7 +35,7 @@ def process_pdf(pipeline, pdf_path, page_range=None):
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
pipeline = load_pipeline(verbose=True, tolerance=3)
|
pipeline = load_pipeline(verbose=False, tolerance=3)
|
||||||
|
|
||||||
if os.path.isfile(args.input):
|
if os.path.isfile(args.input):
|
||||||
pdf_paths = [args.input]
|
pdf_paths = [args.input]
|
||||||
|
|||||||
@ -194,12 +194,16 @@ def input_size(request):
|
|||||||
def array_to_image(array):
|
def array_to_image(array):
|
||||||
assert np.all(array <= 1)
|
assert np.all(array <= 1)
|
||||||
assert np.all(array >= 0)
|
assert np.all(array >= 0)
|
||||||
return Image.fromarray(np.uint8(array * 255), mode="RGB")
|
|
||||||
|
|
||||||
|
if array.shape[-1] == 3:
|
||||||
|
mode = "RGB"
|
||||||
|
elif array.shape[-1] == 4:
|
||||||
|
mode = "RGBA"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected number of channels {array.shape[-1]}. Expected 3 or 4.")
|
||||||
|
|
||||||
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
|
# noinspection PyTypeChecker
|
||||||
def input_size(request):
|
return Image.fromarray(np.uint8(array * 255), mode=mode)
|
||||||
return itemgetter("width", "height", "depth")(request.param)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -17,7 +17,11 @@ def test_image_extractor_mock(image_extractor, images):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"])
|
@pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"])
|
||||||
@pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"])
|
@pytest.mark.parametrize(
|
||||||
|
"input_size",
|
||||||
|
[{"depth": 3, "width": 170, "height": 220}, {"depth": 3, "width": 170, "height": 220}],
|
||||||
|
indirect=["input_size"],
|
||||||
|
)
|
||||||
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size):
|
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size):
|
||||||
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
|
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
|
||||||
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
|
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user