From 03e7b00cfdcece62dad34c01326ea604854228ff Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 14 Apr 2022 12:20:05 +0200 Subject: [PATCH] refactoring --- .../classifier/image_classifier.py | 5 +- .../extractor_classifier.py | 5 +- .../image_extractor/extractors/parsable.py | 104 ++++++++++-------- image_prediction/utils/generic.py | 11 +- test/exploration_tests/funcy_test.py | 28 ++++- test/unit_tests/image_classifier_test.py | 27 ----- 6 files changed, 94 insertions(+), 86 deletions(-) diff --git a/image_prediction/classifier/image_classifier.py b/image_prediction/classifier/image_classifier.py index 1b9ca84..f01cfd4 100644 --- a/image_prediction/classifier/image_classifier.py +++ b/image_prediction/classifier/image_classifier.py @@ -2,13 +2,12 @@ from itertools import chain from typing import Iterable from PIL.Image import Image -from funcy import rcompose +from funcy import rcompose, chunks from image_prediction.classifier.classifier import Classifier from image_prediction.estimator.preprocessor.preprocessor import Preprocessor from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor from image_prediction.utils import get_logger -from image_prediction.utils.generic import chunk_iterable logger = get_logger() @@ -24,7 +23,7 @@ class ImageClassifier: self.pipe = rcompose(self.preprocessor, self.estimator) def predict(self, images: Iterable[Image], batch_size=16): - batches = chunk_iterable(images, chunk_size=batch_size) + batches = chunks(batch_size, images) predictions = chain.from_iterable(map(self.pipe, batches)) return predictions diff --git a/image_prediction/extractor_classifier/extractor_classifier.py b/image_prediction/extractor_classifier/extractor_classifier.py index 9c4c774..ba97b0e 100644 --- a/image_prediction/extractor_classifier/extractor_classifier.py +++ b/image_prediction/extractor_classifier/extractor_classifier.py @@ -1,9 +1,10 @@ from itertools import chain from typing import Iterable +from funcy import chunks + from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.image_extractor.extractor import ImageExtractor -from image_prediction.utils.generic import chunk_iterable class ExtractorClassifier: @@ -26,6 +27,6 @@ class ExtractorClassifier: def __call__(self, obj, **kwargs) -> Iterable[dict]: image_metadata_pairs = self.extractor(obj, **kwargs) - batches = chunk_iterable(image_metadata_pairs, chunk_size=16) + batches = chunks(16, image_metadata_pairs) predictions = chain.from_iterable(map(self.__process_batch, batches)) return predictions diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 9de0884..848b60d 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,13 +1,13 @@ import atexit import io from functools import partial, lru_cache -from itertools import chain, starmap, filterfalse +from itertools import chain, starmap, filterfalse, repeat from operator import itemgetter from typing import List import fitz from PIL import Image -from funcy import rcompose, merge +from funcy import rcompose, merge, zipdict from tqdm import tqdm from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair @@ -74,48 +74,19 @@ def get_metadata_for_images_on_page(doc, page: fitz.Page): metadata = map(get_image_metadata, get_image_infos(page)) metadata = validate_coords_and_passthrough(metadata) - metadata = filterfalse(tiny, metadata) + metadata = filter_out_tiny_images(metadata) metadata = validate_size_and_passthrough(metadata) - metadata = map(partial(merge, get_page_metadata(page)), metadata) + metadata = add_page_metadata(page, metadata) - xrefs = map(itemgetter("xref"), get_image_infos(page)) - alpha = map(partial(has_alpha_channel, doc), xrefs) - alpha = ({Info.ALPHA: a} for a in alpha) - metadata = list(starmap(merge, zip(alpha, metadata))) + metadata = add_alpha_channel_info(doc, page, metadata) yield from metadata -def clear_caches(): - get_image_infos.cache_clear() - load_image_handle_from_xref.cache_clear() - get_images_on_page.cache_clear() - xref_to_image.cache_clear() - - -def validate_coords_and_passthrough(metadata): - yield from map(validate_box_coords, metadata) - - -def validate_size_and_passthrough(metadata): - yield from map(validate_box_size, metadata) - - @lru_cache(maxsize=None) -def load_image_handle_from_xref(doc, xref): - return doc.extract_image(xref) - - -def has_alpha_channel(doc, xref): - - maybe_image = load_image_handle_from_xref(doc, xref) - maybe_smask = maybe_image["smask"] if maybe_image else None - - if maybe_smask: - return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) - else: - return bool(fitz.Pixmap(doc, xref).alpha) +def get_image_infos(page: fitz.Page) -> List[dict]: + return page.get_image_info(xrefs=True) @lru_cache(maxsize=None) @@ -124,11 +95,6 @@ def xref_to_image(doc, xref) -> Image: return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None -@lru_cache(maxsize=None) -def get_image_infos(page: fitz.Page) -> List[dict]: - return page.get_image_info(xrefs=True) - - def get_image_metadata(image_info): x1, y1, x2, y2 = map(rounder, image_info["bbox"]) @@ -146,8 +112,38 @@ def get_image_metadata(image_info): } -def tiny(metadata): - return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4 +def validate_coords_and_passthrough(metadata): + yield from map(validate_box_coords, metadata) + + +def filter_out_tiny_images(metadata): + return filterfalse(tiny, metadata) + + +def validate_size_and_passthrough(metadata): + yield from map(validate_box_size, metadata) + + +def add_page_metadata(page, metadata): + return map(partial(merge, get_page_metadata(page)), metadata) + + +def add_alpha_channel_info(doc, page, metadata): + xrefs = map(itemgetter("xref"), get_image_infos(page)) + alpha = map(partial(has_alpha_channel, doc), xrefs) + alpha = ({Info.ALPHA: a} for a in alpha) + # alpha = map(dict, zip(repeat(Info.ALPHA), alpha)) + metadata = starmap(merge, zip(alpha, metadata)) + + return metadata + + +@lru_cache(maxsize=None) +def load_image_handle_from_xref(doc, xref): + return doc.extract_image(xref) + + +rounder = rcompose(round, int) def get_page_metadata(page): @@ -160,6 +156,26 @@ def get_page_metadata(page): } -rounder = rcompose(round, int) +def has_alpha_channel(doc, xref): + + maybe_image = load_image_handle_from_xref(doc, xref) + maybe_smask = maybe_image["smask"] if maybe_image else None + + if maybe_smask: + return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) + else: + return bool(fitz.Pixmap(doc, xref).alpha) + + +def tiny(metadata): + return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4 + + +def clear_caches(): + get_image_infos.cache_clear() + load_image_handle_from_xref.cache_clear() + get_images_on_page.cache_clear() + xref_to_image.cache_clear() + atexit.register(clear_caches) diff --git a/image_prediction/utils/generic.py b/image_prediction/utils/generic.py index 4b7a2b0..2ef1522 100644 --- a/image_prediction/utils/generic.py +++ b/image_prediction/utils/generic.py @@ -1,14 +1,7 @@ -from itertools import takewhile, starmap, islice, repeat -from operator import truth - -from funcy import iterate - - -def chunk_iterable(iterable, chunk_size): - return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size))))) +from funcy import iterate, chunks def until(cond, func, *args, **kwargs): - for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2): + for a, b in chunks(2, iterate(func, *args, **kwargs)): if cond(a, b): return a diff --git a/test/exploration_tests/funcy_test.py b/test/exploration_tests/funcy_test.py index 5032396..30c2cef 100644 --- a/test/exploration_tests/funcy_test.py +++ b/test/exploration_tests/funcy_test.py @@ -1,6 +1,32 @@ -from funcy import rcompose +import pytest +from funcy import rcompose, chunks def test_rcompose(): f = rcompose(lambda x: x ** 2, str, lambda x: x * 2) assert f(3) == "99" + + +def test_chunk_iterable_exact_split(): + a, b = chunks(5, iter(range(10))) + assert a == list(range(5)) + assert b == list(range(5, 10)) + + +def test_chunk_iterable_no_split(): + a = next(chunks(10, iter(range(10)))) + assert a == list(range(10)) + + +def test_chunk_iterable_last_partial(): + a, b, c, d = chunks(3, iter(range(10))) + assert d == [9] + + +def test_chunk_iterable_empty(): + with pytest.raises(StopIteration): + next(chunks(3, iter(range(0)))) + + +def test_chunk_iterable_less_than_chunk_size_elements(): + assert next(chunks(5, iter(range(2)))) == [0, 1] diff --git a/test/unit_tests/image_classifier_test.py b/test/unit_tests/image_classifier_test.py index 801030a..5cffdf3 100644 --- a/test/unit_tests/image_classifier_test.py +++ b/test/unit_tests/image_classifier_test.py @@ -1,34 +1,7 @@ import pytest -from image_prediction.utils.generic import chunk_iterable - @pytest.mark.parametrize("estimator_type", ["mock", "keras"]) def test_predict(image_classifier, images, batch_of_expected_string_labels): predictions = list(image_classifier.predict(images)) assert predictions == batch_of_expected_string_labels - - -def test_chunk_iterable_exact_split(): - a, b = chunk_iterable(range(10), chunk_size=5) - assert a == tuple(range(5)) - assert b == tuple(range(5, 10)) - - -def test_chunk_iterable_no_split(): - a = next(chunk_iterable(range(10), chunk_size=10)) - assert a == tuple(range(10)) - - -def test_chunk_iterable_last_partial(): - a, b, c, d = chunk_iterable(range(10), chunk_size=3) - assert d == (9,) - - -def test_chunk_iterable_empty(): - with pytest.raises(StopIteration): - next(chunk_iterable(range(0), chunk_size=3)) - - -def test_chunk_iterable_less_than_chunk_size_elements(): - assert next(chunk_iterable(range(2), chunk_size=5)) == (0, 1)