refactoring

This commit is contained in:
Matthias Bisping 2022-04-14 12:20:05 +02:00
parent 7aee00cb49
commit 03e7b00cfd
6 changed files with 94 additions and 86 deletions

View File

@ -2,13 +2,12 @@ from itertools import chain
from typing import Iterable
from PIL.Image import Image
from funcy import rcompose
from funcy import rcompose, chunks
from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
from image_prediction.utils import get_logger
from image_prediction.utils.generic import chunk_iterable
logger = get_logger()
@ -24,7 +23,7 @@ class ImageClassifier:
self.pipe = rcompose(self.preprocessor, self.estimator)
def predict(self, images: Iterable[Image], batch_size=16):
batches = chunk_iterable(images, chunk_size=batch_size)
batches = chunks(batch_size, images)
predictions = chain.from_iterable(map(self.pipe, batches))
return predictions

View File

@ -1,9 +1,10 @@
from itertools import chain
from typing import Iterable
from funcy import chunks
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor
from image_prediction.utils.generic import chunk_iterable
class ExtractorClassifier:
@ -26,6 +27,6 @@ class ExtractorClassifier:
def __call__(self, obj, **kwargs) -> Iterable[dict]:
image_metadata_pairs = self.extractor(obj, **kwargs)
batches = chunk_iterable(image_metadata_pairs, chunk_size=16)
batches = chunks(16, image_metadata_pairs)
predictions = chain.from_iterable(map(self.__process_batch, batches))
return predictions

View File

@ -1,13 +1,13 @@
import atexit
import io
from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse
from itertools import chain, starmap, filterfalse, repeat
from operator import itemgetter
from typing import List
import fitz
from PIL import Image
from funcy import rcompose, merge
from funcy import rcompose, merge, zipdict
from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
@ -74,48 +74,19 @@ def get_metadata_for_images_on_page(doc, page: fitz.Page):
metadata = map(get_image_metadata, get_image_infos(page))
metadata = validate_coords_and_passthrough(metadata)
metadata = filterfalse(tiny, metadata)
metadata = filter_out_tiny_images(metadata)
metadata = validate_size_and_passthrough(metadata)
metadata = map(partial(merge, get_page_metadata(page)), metadata)
metadata = add_page_metadata(page, metadata)
xrefs = map(itemgetter("xref"), get_image_infos(page))
alpha = map(partial(has_alpha_channel, doc), xrefs)
alpha = ({Info.ALPHA: a} for a in alpha)
metadata = list(starmap(merge, zip(alpha, metadata)))
metadata = add_alpha_channel_info(doc, page, metadata)
yield from metadata
def clear_caches():
get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
get_images_on_page.cache_clear()
xref_to_image.cache_clear()
def validate_coords_and_passthrough(metadata):
yield from map(validate_box_coords, metadata)
def validate_size_and_passthrough(metadata):
yield from map(validate_box_size, metadata)
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
return doc.extract_image(xref)
def has_alpha_channel(doc, xref):
maybe_image = load_image_handle_from_xref(doc, xref)
maybe_smask = maybe_image["smask"] if maybe_image else None
if maybe_smask:
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
else:
return bool(fitz.Pixmap(doc, xref).alpha)
def get_image_infos(page: fitz.Page) -> List[dict]:
return page.get_image_info(xrefs=True)
@lru_cache(maxsize=None)
@ -124,11 +95,6 @@ def xref_to_image(doc, xref) -> Image:
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]:
return page.get_image_info(xrefs=True)
def get_image_metadata(image_info):
x1, y1, x2, y2 = map(rounder, image_info["bbox"])
@ -146,8 +112,38 @@ def get_image_metadata(image_info):
}
def tiny(metadata):
return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4
def validate_coords_and_passthrough(metadata):
yield from map(validate_box_coords, metadata)
def filter_out_tiny_images(metadata):
return filterfalse(tiny, metadata)
def validate_size_and_passthrough(metadata):
yield from map(validate_box_size, metadata)
def add_page_metadata(page, metadata):
return map(partial(merge, get_page_metadata(page)), metadata)
def add_alpha_channel_info(doc, page, metadata):
xrefs = map(itemgetter("xref"), get_image_infos(page))
alpha = map(partial(has_alpha_channel, doc), xrefs)
alpha = ({Info.ALPHA: a} for a in alpha)
# alpha = map(dict, zip(repeat(Info.ALPHA), alpha))
metadata = starmap(merge, zip(alpha, metadata))
return metadata
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
return doc.extract_image(xref)
rounder = rcompose(round, int)
def get_page_metadata(page):
@ -160,6 +156,26 @@ def get_page_metadata(page):
}
rounder = rcompose(round, int)
def has_alpha_channel(doc, xref):
maybe_image = load_image_handle_from_xref(doc, xref)
maybe_smask = maybe_image["smask"] if maybe_image else None
if maybe_smask:
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
else:
return bool(fitz.Pixmap(doc, xref).alpha)
def tiny(metadata):
return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4
def clear_caches():
get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
get_images_on_page.cache_clear()
xref_to_image.cache_clear()
atexit.register(clear_caches)

View File

@ -1,14 +1,7 @@
from itertools import takewhile, starmap, islice, repeat
from operator import truth
from funcy import iterate
def chunk_iterable(iterable, chunk_size):
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
from funcy import iterate, chunks
def until(cond, func, *args, **kwargs):
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
for a, b in chunks(2, iterate(func, *args, **kwargs)):
if cond(a, b):
return a

View File

@ -1,6 +1,32 @@
from funcy import rcompose
import pytest
from funcy import rcompose, chunks
def test_rcompose():
f = rcompose(lambda x: x ** 2, str, lambda x: x * 2)
assert f(3) == "99"
def test_chunk_iterable_exact_split():
a, b = chunks(5, iter(range(10)))
assert a == list(range(5))
assert b == list(range(5, 10))
def test_chunk_iterable_no_split():
a = next(chunks(10, iter(range(10))))
assert a == list(range(10))
def test_chunk_iterable_last_partial():
a, b, c, d = chunks(3, iter(range(10)))
assert d == [9]
def test_chunk_iterable_empty():
with pytest.raises(StopIteration):
next(chunks(3, iter(range(0))))
def test_chunk_iterable_less_than_chunk_size_elements():
assert next(chunks(5, iter(range(2)))) == [0, 1]

View File

@ -1,34 +1,7 @@
import pytest
from image_prediction.utils.generic import chunk_iterable
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
def test_predict(image_classifier, images, batch_of_expected_string_labels):
predictions = list(image_classifier.predict(images))
assert predictions == batch_of_expected_string_labels
def test_chunk_iterable_exact_split():
a, b = chunk_iterable(range(10), chunk_size=5)
assert a == tuple(range(5))
assert b == tuple(range(5, 10))
def test_chunk_iterable_no_split():
a = next(chunk_iterable(range(10), chunk_size=10))
assert a == tuple(range(10))
def test_chunk_iterable_last_partial():
a, b, c, d = chunk_iterable(range(10), chunk_size=3)
assert d == (9,)
def test_chunk_iterable_empty():
with pytest.raises(StopIteration):
next(chunk_iterable(range(0), chunk_size=3))
def test_chunk_iterable_less_than_chunk_size_elements():
assert next(chunk_iterable(range(2), chunk_size=5)) == (0, 1)