Refactoring
- Rename - Refactor image extraction functions
This commit is contained in:
parent
76b1b0ca24
commit
e99e97e23f
@ -1,9 +1,9 @@
|
|||||||
import atexit
|
import atexit
|
||||||
from _operator import itemgetter
|
from _operator import itemgetter
|
||||||
from functools import partial, lru_cache
|
from functools import partial, lru_cache, singledispatch
|
||||||
from itertools import chain, filterfalse
|
from itertools import chain, filterfalse
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import List, Union, Iterable
|
from typing import List, Union, Iterable, Any
|
||||||
|
|
||||||
import fitz
|
import fitz
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -19,14 +19,15 @@ from image_prediction.info import Info
|
|||||||
from image_prediction.stitching.stitching import stitch_pairs
|
from image_prediction.stitching.stitching import stitch_pairs
|
||||||
from image_prediction.stitching.utils import validate_box
|
from image_prediction.stitching.utils import validate_box
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
from image_prediction.utils.generic import bottom, left, right
|
from image_prediction.utils.generic import bottom, left, right, lift
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
Doc = fitz.fitz.Document
|
Doc = fitz.fitz.Document
|
||||||
Pag = fitz.fitz.Page
|
Pge = fitz.fitz.Page
|
||||||
Img = Image.Image
|
Img = Image.Image
|
||||||
|
Pxm = fitz.fitz.Pixmap
|
||||||
|
|
||||||
|
|
||||||
class ParsablePDFImageExtractor(ImageExtractor):
|
class ParsablePDFImageExtractor(ImageExtractor):
|
||||||
@ -51,7 +52,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
|
|
||||||
yield from image_metadata_pairs
|
yield from image_metadata_pairs
|
||||||
|
|
||||||
def __process_images_on_page(self, page: Pag):
|
def __process_images_on_page(self, page: Pge):
|
||||||
metadata = extract_valid_metadata(self.doc, page)
|
metadata = extract_valid_metadata(self.doc, page)
|
||||||
|
|
||||||
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
|
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
|
||||||
@ -96,7 +97,7 @@ def validate_image(image: Img) -> Either:
|
|||||||
return Left("Invalid image.")
|
return Left("Invalid image.")
|
||||||
|
|
||||||
|
|
||||||
def extract_valid_metadata(doc: Doc, page: Pag) -> List[dict]:
|
def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
|
||||||
return compose(
|
return compose(
|
||||||
list,
|
list,
|
||||||
partial(add_alpha_channel_info, doc),
|
partial(add_alpha_channel_info, doc),
|
||||||
@ -106,8 +107,8 @@ def extract_valid_metadata(doc: Doc, page: Pag) -> List[dict]:
|
|||||||
|
|
||||||
|
|
||||||
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
|
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
|
||||||
image: Either = xref_to_image(doc, metadatum[Info.XREF]).bind(validate_image)
|
image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
|
||||||
image_metadata_pair: Either = make_eithered_image_metadata_pair(image, Right(metadatum))
|
image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
|
||||||
return image_metadata_pair
|
return image_metadata_pair
|
||||||
|
|
||||||
|
|
||||||
@ -127,21 +128,24 @@ def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
|
|||||||
|
|
||||||
|
|
||||||
def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
|
def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
|
||||||
metadata = map(get_image_metadata, get_image_infos(page))
|
metadata = compose(
|
||||||
metadata = add_page_metadata(page, metadata)
|
partial(add_page_metadata, page),
|
||||||
|
lift(get_image_metadata),
|
||||||
|
get_image_infos,
|
||||||
|
)(page)
|
||||||
|
|
||||||
yield from metadata
|
yield from metadata
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def xref_to_image(doc: Doc, xref: int) -> Either:
|
def eith_extract_image(doc: Doc, xref: int) -> Either:
|
||||||
try:
|
try:
|
||||||
return Right(extract_image(doc, xref))
|
return Right(extract_image(doc, xref))
|
||||||
except BadXref:
|
except BadXref:
|
||||||
return Left("Bad xref.")
|
return Left("Bad xref.")
|
||||||
|
|
||||||
|
|
||||||
def make_eithered_image_metadata_pair(image: Either, metadata: Either) -> Either:
|
def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either:
|
||||||
"""Reference: haskell.org/tutorial/monads.html"""
|
"""Reference: haskell.org/tutorial/monads.html"""
|
||||||
|
|
||||||
def context(value):
|
def context(value):
|
||||||
@ -180,18 +184,25 @@ def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
|
|||||||
return ImageMetadataPair(image, metadatum)
|
return ImageMetadataPair(image, metadatum)
|
||||||
|
|
||||||
|
|
||||||
def extract_image(doc: Doc, xref: int) -> Img:
|
@singledispatch
|
||||||
|
def extract_image(doc: Doc, xref: int) -> Any:
|
||||||
|
return compose(pixmap_to_image, extract_pixmap)(doc, xref)
|
||||||
|
|
||||||
|
|
||||||
|
def pixmap_to_image(pixmap: Pxm) -> Img:
|
||||||
|
array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n))
|
||||||
|
array = normalize_channels(array)
|
||||||
|
return Image.fromarray(array)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pixmap(doc: Doc, xref: int) -> Pxm:
|
||||||
try:
|
try:
|
||||||
pixmap = fitz.Pixmap(doc, xref)
|
return fitz.Pixmap(doc, xref)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
msg = f"Xref {xref} is invalid, skipping extraction."
|
msg = f"Xref {xref} is invalid, skipping extraction."
|
||||||
logger.debug(msg)
|
logger.debug(msg)
|
||||||
raise BadXref(msg) from err
|
raise BadXref(msg) from err
|
||||||
|
|
||||||
array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n))
|
|
||||||
array = normalize_channels(array)
|
|
||||||
return Image.fromarray(array)
|
|
||||||
|
|
||||||
|
|
||||||
def has_alpha_channel(doc: Doc, xref: int):
|
def has_alpha_channel(doc: Doc, xref: int):
|
||||||
|
|
||||||
@ -199,7 +210,9 @@ def has_alpha_channel(doc: Doc, xref: int):
|
|||||||
maybe_smask = maybe_image["smask"] if maybe_image else None
|
maybe_smask = maybe_image["smask"] if maybe_image else None
|
||||||
|
|
||||||
if maybe_smask: # TODO: Use monad.
|
if maybe_smask: # TODO: Use monad.
|
||||||
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
|
return any(
|
||||||
|
[load_image_handle_from_xref(doc, maybe_smask) is not None, bool(extract_pixmap(doc, maybe_smask).alpha)]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return bool(fitz.Pixmap(doc, xref).alpha)
|
return bool(fitz.Pixmap(doc, xref).alpha)
|
||||||
@ -242,11 +255,11 @@ def get_image_metadata(image_info: dict) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def get_image_infos(page: Pag) -> List[dict]:
|
def get_image_infos(page: Pge) -> List[dict]:
|
||||||
return page.get_image_info(xrefs=True)
|
return page.get_image_info(xrefs=True)
|
||||||
|
|
||||||
|
|
||||||
def add_page_metadata(page: Pag, metadata: Iterable[dict]) -> Iterable[dict]:
|
def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]:
|
||||||
yield from map(partial(merge, get_page_metadata(page)), metadata)
|
yield from map(partial(merge, get_page_metadata(page)), metadata)
|
||||||
|
|
||||||
|
|
||||||
@ -270,7 +283,7 @@ def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TO
|
|||||||
return doc.extract_image(xref)
|
return doc.extract_image(xref)
|
||||||
|
|
||||||
|
|
||||||
def get_page_metadata(page: Pag) -> dict:
|
def get_page_metadata(page: Pge) -> dict:
|
||||||
page_width, page_height = map(rounder, page.mediabox_size)
|
page_width, page_height = map(rounder, page.mediabox_size)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -290,7 +303,7 @@ def tiny(metadatum: dict) -> bool:
|
|||||||
def clear_caches() -> None:
|
def clear_caches() -> None:
|
||||||
get_image_infos.cache_clear()
|
get_image_infos.cache_clear()
|
||||||
load_image_handle_from_xref.cache_clear()
|
load_image_handle_from_xref.cache_clear()
|
||||||
xref_to_image.cache_clear()
|
eith_extract_image.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
atexit.register(clear_caches)
|
atexit.register(clear_caches)
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from image_prediction.image_extractor.extractors.parsable import (
|
|||||||
get_image_infos,
|
get_image_infos,
|
||||||
ParsablePDFImageExtractor,
|
ParsablePDFImageExtractor,
|
||||||
extract_valid_metadata,
|
extract_valid_metadata,
|
||||||
xref_to_image,
|
eith_extract_image,
|
||||||
extract_image,
|
extract_image,
|
||||||
)
|
)
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
@ -98,4 +98,4 @@ def test_bad_xref_handling(bad_xref_pdf, dvc_test_data):
|
|||||||
with pytest.raises(BadXref):
|
with pytest.raises(BadXref):
|
||||||
extract_image(doc, xref)
|
extract_image(doc, xref)
|
||||||
|
|
||||||
assert xref_to_image(doc, xref).is_left()
|
assert eith_extract_image(doc, xref).is_left()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user