Refactoring
- Rename - Add typehints everywhere
This commit is contained in:
parent
865e0819a1
commit
3cea4dad2d
@ -3,12 +3,12 @@ from _operator import itemgetter
|
||||
from functools import partial, lru_cache
|
||||
from itertools import chain, filterfalse
|
||||
from operator import itemgetter
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
import fitz
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from funcy import merge, compose, rcompose, keep
|
||||
from funcy import merge, compose, rcompose, keep, lkeep
|
||||
from pymonad.either import Right, Left, Either
|
||||
from pymonad.tools import curry, identity
|
||||
|
||||
@ -24,6 +24,11 @@ from image_prediction.utils.generic import bottom, left, right
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
Doc = fitz.fitz.Document
|
||||
Pag = fitz.fitz.Page
|
||||
Img = Image.Image
|
||||
|
||||
|
||||
class ParsablePDFImageExtractor(ImageExtractor):
|
||||
def __init__(self, verbose=False, tolerance=0):
|
||||
"""
|
||||
@ -33,7 +38,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched
|
||||
together
|
||||
"""
|
||||
self.doc: Union[fitz.fitz.Document, None] = None
|
||||
self.doc: Union[Doc, None] = None
|
||||
self.verbose = verbose
|
||||
self.tolerance = tolerance
|
||||
|
||||
@ -46,36 +51,42 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
|
||||
yield from image_metadata_pairs
|
||||
|
||||
def __process_images_on_page(self, page: fitz.fitz.Page):
|
||||
def __process_images_on_page(self, page: Pag):
|
||||
metadata = extract_valid_metadata(self.doc, page)
|
||||
|
||||
maybe_image_metadata_pairs = map(partial(metadatum_to_image_metadata_pair, self.doc), metadata)
|
||||
image_metadata_pairs = keep(take_right, maybe_image_metadata_pairs)
|
||||
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
|
||||
valid_image_metadata_pairs = lkeep(take_good_log_bad, either_image_metadata_pair_or_error_per_image)
|
||||
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
|
||||
|
||||
clear_caches()
|
||||
|
||||
image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance)
|
||||
yield from valid_image_metadata_pairs_stitched
|
||||
|
||||
yield from image_metadata_pairs
|
||||
def __metadatum_to_image_metadata_pair(self, metadatum: dict) -> Either:
|
||||
return metadatum_to_image_metadata_pair(self.doc, metadatum)
|
||||
|
||||
|
||||
def take_right(pair: Either):
|
||||
if pair.is_right():
|
||||
return pair.either(bottom, identity)
|
||||
logger.warning(f"Skipping bad image. {pair.either(format_context, bottom)}")
|
||||
def take_good_log_bad(pair: Either) -> Union[ImageMetadataPair, None]:
|
||||
return pair.either(log_error_context, identity)
|
||||
|
||||
|
||||
def format_context(context):
|
||||
def log_error_context(context: dict) -> None:
|
||||
logger.warning(f"Skipping bad image. {format_context(context)}")
|
||||
return None
|
||||
|
||||
|
||||
def format_context(context: dict) -> str:
|
||||
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
|
||||
|
||||
|
||||
def extract_pages(doc, page_range):
|
||||
def extract_pages(doc: Doc, page_range: range):
|
||||
page_range = range(page_range.start + 1, page_range.stop + 1)
|
||||
pages = map(doc.load_page, page_range)
|
||||
|
||||
yield from pages
|
||||
|
||||
|
||||
def validate_image(image: Image.Image) -> Either:
|
||||
def validate_image(image: Img) -> Either:
|
||||
try:
|
||||
# TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
|
||||
image.resize((100, 100)).convert("RGB")
|
||||
@ -85,7 +96,7 @@ def validate_image(image: Image.Image) -> Either:
|
||||
return Left("Invalid image.")
|
||||
|
||||
|
||||
def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page):
|
||||
def extract_valid_metadata(doc: Doc, page: Pag) -> List[dict]:
|
||||
return compose(
|
||||
list,
|
||||
partial(add_alpha_channel_info, doc),
|
||||
@ -94,14 +105,14 @@ def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page):
|
||||
)(page)
|
||||
|
||||
|
||||
def metadatum_to_image_metadata_pair(doc, metadatum: dict) -> Either:
|
||||
maybe_image = xref_to_maybe_image(doc, metadatum[Info.XREF]).bind(validate_image)
|
||||
maybe_image_metadata_pair = make_maybe_image_metadata_pair(maybe_image, Right(metadatum))
|
||||
return maybe_image_metadata_pair
|
||||
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
|
||||
image: Either = xref_to_image(doc, metadatum[Info.XREF]).bind(validate_image)
|
||||
image_metadata_pair: Either = make_eithered_image_metadata_pair(image, Right(metadatum))
|
||||
return image_metadata_pair
|
||||
|
||||
|
||||
def add_alpha_channel_info(doc, metadata):
|
||||
def add_alpha_value_to_metadatum(metadatum):
|
||||
def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
def add_alpha_value_to_metadatum(metadatum: dict) -> dict:
|
||||
alpha = metadatum_to_alpha_value(metadatum)
|
||||
return {**metadatum, Info.ALPHA: alpha}
|
||||
|
||||
@ -111,11 +122,11 @@ def add_alpha_channel_info(doc, metadata):
|
||||
yield from map(add_alpha_value_to_metadatum, metadata)
|
||||
|
||||
|
||||
def filter_valid_metadata(metadata):
|
||||
def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
yield from compose(filter_out_tiny_images, filter_out_invalid_metadata)(metadata)
|
||||
|
||||
|
||||
def get_metadata_for_images_on_page(page: fitz.Page):
|
||||
def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
|
||||
metadata = map(get_image_metadata, get_image_infos(page))
|
||||
metadata = add_page_metadata(page, metadata)
|
||||
|
||||
@ -123,14 +134,14 @@ def get_metadata_for_images_on_page(page: fitz.Page):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def xref_to_maybe_image(doc, xref) -> Either:
|
||||
def xref_to_image(doc: Doc, xref: int) -> Either:
|
||||
try:
|
||||
return Right(extract_image(doc, xref))
|
||||
except BadXref:
|
||||
return Left("Bad xref.")
|
||||
|
||||
|
||||
def make_maybe_image_metadata_pair(image: Either, metadata: Either):
|
||||
def make_eithered_image_metadata_pair(image: Either, metadata: Either) -> Either:
|
||||
"""Reference: haskell.org/tutorial/monads.html"""
|
||||
|
||||
def context(value):
|
||||
@ -160,11 +171,11 @@ def make_maybe_image_metadata_pair(image: Either, metadata: Either):
|
||||
|
||||
|
||||
@curry(2)
|
||||
def make_image_metadata_pair(image: Image.Image, metadatum: dict) -> ImageMetadataPair:
|
||||
def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
|
||||
return ImageMetadataPair(image, metadatum)
|
||||
|
||||
|
||||
def extract_image(doc, xref) -> Image.Image:
|
||||
def extract_image(doc: Doc, xref: int) -> Img:
|
||||
try:
|
||||
pixmap = fitz.Pixmap(doc, xref)
|
||||
except ValueError as err:
|
||||
@ -177,12 +188,12 @@ def extract_image(doc, xref) -> Image.Image:
|
||||
return Image.fromarray(array)
|
||||
|
||||
|
||||
def has_alpha_channel(doc, xref):
|
||||
def has_alpha_channel(doc: Doc, xref: int):
|
||||
|
||||
maybe_image = load_image_handle_from_xref(doc, xref)
|
||||
maybe_smask = maybe_image["smask"] if maybe_image else None
|
||||
|
||||
if maybe_smask:
|
||||
if maybe_smask: # Use monad
|
||||
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
|
||||
else:
|
||||
try:
|
||||
@ -192,11 +203,11 @@ def has_alpha_channel(doc, xref):
|
||||
return False
|
||||
|
||||
|
||||
def filter_out_tiny_images(metadata):
|
||||
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
yield from filterfalse(tiny, metadata)
|
||||
|
||||
|
||||
def filter_out_invalid_metadata(metadata):
|
||||
def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
def __validate_box(box):
|
||||
try:
|
||||
return validate_box(box)
|
||||
@ -206,7 +217,7 @@ def filter_out_invalid_metadata(metadata):
|
||||
yield from keep(__validate_box, metadata)
|
||||
|
||||
|
||||
def get_image_metadata(image_info):
|
||||
def get_image_metadata(image_info: dict) -> dict:
|
||||
|
||||
xref, coords = itemgetter("xref", "bbox")(image_info)
|
||||
x1, y1, x2, y2 = map(rounder, coords)
|
||||
@ -226,11 +237,11 @@ def get_image_metadata(image_info):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_image_infos(page: fitz.Page) -> List[dict]:
|
||||
def get_image_infos(page: Pag) -> List[dict]:
|
||||
return page.get_image_info(xrefs=True)
|
||||
|
||||
|
||||
def add_page_metadata(page, metadata):
|
||||
def add_page_metadata(page: Pag, metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
yield from map(partial(merge, get_page_metadata(page)), metadata)
|
||||
|
||||
|
||||
@ -250,11 +261,11 @@ def normalize_channels(array: np.ndarray):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def load_image_handle_from_xref(doc, xref):
|
||||
def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TODO: use Monad
|
||||
return doc.extract_image(xref)
|
||||
|
||||
|
||||
def get_page_metadata(page):
|
||||
def get_page_metadata(page: Pag) -> dict:
|
||||
page_width, page_height = map(rounder, page.mediabox_size)
|
||||
|
||||
return {
|
||||
@ -267,14 +278,14 @@ def get_page_metadata(page):
|
||||
rounder = rcompose(round, int)
|
||||
|
||||
|
||||
def tiny(metadatum):
|
||||
def tiny(metadatum: dict) -> bool:
|
||||
return metadatum[Info.WIDTH] * metadatum[Info.HEIGHT] <= 4
|
||||
|
||||
|
||||
def clear_caches():
|
||||
def clear_caches() -> None:
|
||||
get_image_infos.cache_clear()
|
||||
load_image_handle_from_xref.cache_clear()
|
||||
xref_to_maybe_image.cache_clear()
|
||||
xref_to_image.cache_clear()
|
||||
|
||||
|
||||
atexit.register(clear_caches)
|
||||
|
||||
@ -18,7 +18,7 @@ from image_prediction.image_extractor.extractors.parsable import (
|
||||
get_image_infos,
|
||||
ParsablePDFImageExtractor,
|
||||
extract_valid_metadata,
|
||||
xref_to_maybe_image,
|
||||
xref_to_image,
|
||||
extract_image,
|
||||
)
|
||||
from image_prediction.info import Info
|
||||
@ -98,4 +98,4 @@ def test_bad_xref_handling(bad_xref_pdf, dvc_test_data):
|
||||
with pytest.raises(BadXref):
|
||||
extract_image(doc, xref)
|
||||
|
||||
assert xref_to_maybe_image(doc, xref).is_left()
|
||||
assert xref_to_image(doc, xref).is_left()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user