refactoring

This commit is contained in:
Matthias Bisping 2022-04-14 12:20:05 +02:00
parent 7aee00cb49
commit 03e7b00cfd
6 changed files with 94 additions and 86 deletions

View File

@ -2,13 +2,12 @@ from itertools import chain
from typing import Iterable from typing import Iterable
from PIL.Image import Image from PIL.Image import Image
from funcy import rcompose from funcy import rcompose, chunks
from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
from image_prediction.utils import get_logger from image_prediction.utils import get_logger
from image_prediction.utils.generic import chunk_iterable
logger = get_logger() logger = get_logger()
@ -24,7 +23,7 @@ class ImageClassifier:
self.pipe = rcompose(self.preprocessor, self.estimator) self.pipe = rcompose(self.preprocessor, self.estimator)
def predict(self, images: Iterable[Image], batch_size=16): def predict(self, images: Iterable[Image], batch_size=16):
batches = chunk_iterable(images, chunk_size=batch_size) batches = chunks(batch_size, images)
predictions = chain.from_iterable(map(self.pipe, batches)) predictions = chain.from_iterable(map(self.pipe, batches))
return predictions return predictions

View File

@ -1,9 +1,10 @@
from itertools import chain from itertools import chain
from typing import Iterable from typing import Iterable
from funcy import chunks
from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor from image_prediction.image_extractor.extractor import ImageExtractor
from image_prediction.utils.generic import chunk_iterable
class ExtractorClassifier: class ExtractorClassifier:
@ -26,6 +27,6 @@ class ExtractorClassifier:
def __call__(self, obj, **kwargs) -> Iterable[dict]: def __call__(self, obj, **kwargs) -> Iterable[dict]:
image_metadata_pairs = self.extractor(obj, **kwargs) image_metadata_pairs = self.extractor(obj, **kwargs)
batches = chunk_iterable(image_metadata_pairs, chunk_size=16) batches = chunks(16, image_metadata_pairs)
predictions = chain.from_iterable(map(self.__process_batch, batches)) predictions = chain.from_iterable(map(self.__process_batch, batches))
return predictions return predictions

View File

@ -1,13 +1,13 @@
import atexit import atexit
import io import io
from functools import partial, lru_cache from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse from itertools import chain, starmap, filterfalse, repeat
from operator import itemgetter from operator import itemgetter
from typing import List from typing import List
import fitz import fitz
from PIL import Image from PIL import Image
from funcy import rcompose, merge from funcy import rcompose, merge, zipdict
from tqdm import tqdm from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
@ -74,48 +74,19 @@ def get_metadata_for_images_on_page(doc, page: fitz.Page):
metadata = map(get_image_metadata, get_image_infos(page)) metadata = map(get_image_metadata, get_image_infos(page))
metadata = validate_coords_and_passthrough(metadata) metadata = validate_coords_and_passthrough(metadata)
metadata = filterfalse(tiny, metadata) metadata = filter_out_tiny_images(metadata)
metadata = validate_size_and_passthrough(metadata) metadata = validate_size_and_passthrough(metadata)
metadata = map(partial(merge, get_page_metadata(page)), metadata) metadata = add_page_metadata(page, metadata)
xrefs = map(itemgetter("xref"), get_image_infos(page)) metadata = add_alpha_channel_info(doc, page, metadata)
alpha = map(partial(has_alpha_channel, doc), xrefs)
alpha = ({Info.ALPHA: a} for a in alpha)
metadata = list(starmap(merge, zip(alpha, metadata)))
yield from metadata yield from metadata
def clear_caches():
get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
get_images_on_page.cache_clear()
xref_to_image.cache_clear()
def validate_coords_and_passthrough(metadata):
yield from map(validate_box_coords, metadata)
def validate_size_and_passthrough(metadata):
yield from map(validate_box_size, metadata)
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref): def get_image_infos(page: fitz.Page) -> List[dict]:
return doc.extract_image(xref) return page.get_image_info(xrefs=True)
def has_alpha_channel(doc, xref):
maybe_image = load_image_handle_from_xref(doc, xref)
maybe_smask = maybe_image["smask"] if maybe_image else None
if maybe_smask:
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
else:
return bool(fitz.Pixmap(doc, xref).alpha)
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
@ -124,11 +95,6 @@ def xref_to_image(doc, xref) -> Image:
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]:
return page.get_image_info(xrefs=True)
def get_image_metadata(image_info): def get_image_metadata(image_info):
x1, y1, x2, y2 = map(rounder, image_info["bbox"]) x1, y1, x2, y2 = map(rounder, image_info["bbox"])
@ -146,8 +112,38 @@ def get_image_metadata(image_info):
} }
def tiny(metadata): def validate_coords_and_passthrough(metadata):
return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4 yield from map(validate_box_coords, metadata)
def filter_out_tiny_images(metadata):
return filterfalse(tiny, metadata)
def validate_size_and_passthrough(metadata):
yield from map(validate_box_size, metadata)
def add_page_metadata(page, metadata):
return map(partial(merge, get_page_metadata(page)), metadata)
def add_alpha_channel_info(doc, page, metadata):
xrefs = map(itemgetter("xref"), get_image_infos(page))
alpha = map(partial(has_alpha_channel, doc), xrefs)
alpha = ({Info.ALPHA: a} for a in alpha)
# alpha = map(dict, zip(repeat(Info.ALPHA), alpha))
metadata = starmap(merge, zip(alpha, metadata))
return metadata
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
return doc.extract_image(xref)
rounder = rcompose(round, int)
def get_page_metadata(page): def get_page_metadata(page):
@ -160,6 +156,26 @@ def get_page_metadata(page):
} }
rounder = rcompose(round, int) def has_alpha_channel(doc, xref):
maybe_image = load_image_handle_from_xref(doc, xref)
maybe_smask = maybe_image["smask"] if maybe_image else None
if maybe_smask:
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
else:
return bool(fitz.Pixmap(doc, xref).alpha)
def tiny(metadata):
return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4
def clear_caches():
get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
get_images_on_page.cache_clear()
xref_to_image.cache_clear()
atexit.register(clear_caches) atexit.register(clear_caches)

View File

@ -1,14 +1,7 @@
from itertools import takewhile, starmap, islice, repeat from funcy import iterate, chunks
from operator import truth
from funcy import iterate
def chunk_iterable(iterable, chunk_size):
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
def until(cond, func, *args, **kwargs): def until(cond, func, *args, **kwargs):
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2): for a, b in chunks(2, iterate(func, *args, **kwargs)):
if cond(a, b): if cond(a, b):
return a return a

View File

@ -1,6 +1,32 @@
from funcy import rcompose import pytest
from funcy import rcompose, chunks
def test_rcompose(): def test_rcompose():
f = rcompose(lambda x: x ** 2, str, lambda x: x * 2) f = rcompose(lambda x: x ** 2, str, lambda x: x * 2)
assert f(3) == "99" assert f(3) == "99"
def test_chunk_iterable_exact_split():
a, b = chunks(5, iter(range(10)))
assert a == list(range(5))
assert b == list(range(5, 10))
def test_chunk_iterable_no_split():
a = next(chunks(10, iter(range(10))))
assert a == list(range(10))
def test_chunk_iterable_last_partial():
a, b, c, d = chunks(3, iter(range(10)))
assert d == [9]
def test_chunk_iterable_empty():
with pytest.raises(StopIteration):
next(chunks(3, iter(range(0))))
def test_chunk_iterable_less_than_chunk_size_elements():
assert next(chunks(5, iter(range(2)))) == [0, 1]

View File

@ -1,34 +1,7 @@
import pytest import pytest
from image_prediction.utils.generic import chunk_iterable
@pytest.mark.parametrize("estimator_type", ["mock", "keras"]) @pytest.mark.parametrize("estimator_type", ["mock", "keras"])
def test_predict(image_classifier, images, batch_of_expected_string_labels): def test_predict(image_classifier, images, batch_of_expected_string_labels):
predictions = list(image_classifier.predict(images)) predictions = list(image_classifier.predict(images))
assert predictions == batch_of_expected_string_labels assert predictions == batch_of_expected_string_labels
def test_chunk_iterable_exact_split():
a, b = chunk_iterable(range(10), chunk_size=5)
assert a == tuple(range(5))
assert b == tuple(range(5, 10))
def test_chunk_iterable_no_split():
a = next(chunk_iterable(range(10), chunk_size=10))
assert a == tuple(range(10))
def test_chunk_iterable_last_partial():
a, b, c, d = chunk_iterable(range(10), chunk_size=3)
assert d == (9,)
def test_chunk_iterable_empty():
with pytest.raises(StopIteration):
next(chunk_iterable(range(0), chunk_size=3))
def test_chunk_iterable_less_than_chunk_size_elements():
assert next(chunk_iterable(range(2), chunk_size=5)) == (0, 1)