refactoring
This commit is contained in:
parent
7aee00cb49
commit
03e7b00cfd
@ -2,13 +2,12 @@ from itertools import chain
|
||||
from typing import Iterable
|
||||
|
||||
from PIL.Image import Image
|
||||
from funcy import rcompose
|
||||
from funcy import rcompose, chunks
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
|
||||
from image_prediction.utils import get_logger
|
||||
from image_prediction.utils.generic import chunk_iterable
|
||||
|
||||
# Module-level logger obtained from the project's ``get_logger`` helper.
logger = get_logger()
|
||||
|
||||
@ -24,7 +23,7 @@ class ImageClassifier:
|
||||
self.pipe = rcompose(self.preprocessor, self.estimator)
|
||||
|
||||
def predict(self, images: Iterable[Image], batch_size=16):
    """Lazily run ``self.pipe`` over ``images`` in batches.

    Args:
        images: Iterable of PIL images to classify.
        batch_size: Maximum number of images handed to ``self.pipe`` per call.

    Returns:
        A lazy iterator that chains the per-batch results of ``self.pipe``
        into one flat stream; nothing is computed until it is consumed.
    """
    # Only the funcy-based batching is kept: the former ``chunk_iterable``
    # assignment was overwritten immediately and the helper is being removed.
    batches = chunks(batch_size, images)
    predictions = chain.from_iterable(map(self.pipe, batches))
    return predictions
|
||||
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from itertools import chain
|
||||
from typing import Iterable
|
||||
|
||||
from funcy import chunks
|
||||
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.image_extractor.extractor import ImageExtractor
|
||||
from image_prediction.utils.generic import chunk_iterable
|
||||
|
||||
|
||||
class ExtractorClassifier:
|
||||
@ -26,6 +27,6 @@ class ExtractorClassifier:
|
||||
|
||||
def __call__(self, obj, **kwargs) -> Iterable[dict]:
    """Extract image/metadata pairs from ``obj`` and process them in batches.

    Args:
        obj: Input forwarded unchanged to ``self.extractor``.
        **kwargs: Extra keyword arguments forwarded to ``self.extractor``.

    Returns:
        A lazy iterator chaining the results that ``self.__process_batch``
        produces for each batch of up to 16 pairs.
    """
    image_metadata_pairs = self.extractor(obj, **kwargs)
    # Only the funcy-based batching is kept: the former ``chunk_iterable``
    # assignment was overwritten immediately and the helper is being removed.
    batches = chunks(16, image_metadata_pairs)
    predictions = chain.from_iterable(map(self.__process_batch, batches))
    return predictions
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import atexit
|
||||
import io
|
||||
from functools import partial, lru_cache
|
||||
from itertools import chain, starmap, filterfalse
|
||||
from itertools import chain, starmap, filterfalse, repeat
|
||||
from operator import itemgetter
|
||||
from typing import List
|
||||
|
||||
import fitz
|
||||
from PIL import Image
|
||||
from funcy import rcompose, merge
|
||||
from funcy import rcompose, merge, zipdict
|
||||
from tqdm import tqdm
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||
@ -74,48 +74,19 @@ def get_metadata_for_images_on_page(doc, page: fitz.Page):
|
||||
metadata = map(get_image_metadata, get_image_infos(page))
|
||||
metadata = validate_coords_and_passthrough(metadata)
|
||||
|
||||
metadata = filterfalse(tiny, metadata)
|
||||
metadata = filter_out_tiny_images(metadata)
|
||||
metadata = validate_size_and_passthrough(metadata)
|
||||
|
||||
metadata = map(partial(merge, get_page_metadata(page)), metadata)
|
||||
metadata = add_page_metadata(page, metadata)
|
||||
|
||||
xrefs = map(itemgetter("xref"), get_image_infos(page))
|
||||
alpha = map(partial(has_alpha_channel, doc), xrefs)
|
||||
alpha = ({Info.ALPHA: a} for a in alpha)
|
||||
metadata = list(starmap(merge, zip(alpha, metadata)))
|
||||
metadata = add_alpha_channel_info(doc, page, metadata)
|
||||
|
||||
yield from metadata
|
||||
|
||||
|
||||
def clear_caches():
    """Reset the caches of every cached helper in this module."""
    cached_helpers = (
        get_image_infos,
        load_image_handle_from_xref,
        get_images_on_page,
        xref_to_image,
    )
    for helper in cached_helpers:
        helper.cache_clear()
|
||||
|
||||
|
||||
def validate_coords_and_passthrough(metadata):
    """Lazily yield ``validate_box_coords(item)`` for each item of ``metadata``."""
    for item in metadata:
        yield validate_box_coords(item)
|
||||
|
||||
|
||||
def validate_size_and_passthrough(metadata):
    """Lazily yield ``validate_box_size(item)`` for each item of ``metadata``."""
    for item in metadata:
        yield validate_box_size(item)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
    """Return ``doc.extract_image(xref)``, memoised per ``(doc, xref)`` pair.

    The cache is unbounded; ``clear_caches`` resets it.
    """
    image_handle = doc.extract_image(xref)
    return image_handle
|
||||
|
||||
|
||||
def has_alpha_channel(doc, xref):
    """Return ``True`` when the image at ``xref`` appears to carry alpha information.

    If the extracted image handle has a truthy ``"smask"`` entry, the result is
    ``True`` when either the smask can be extracted from ``doc`` or the smask's
    pixmap reports alpha. Otherwise the alpha flag of the image's own pixmap
    decides.
    """
    image_handle = load_image_handle_from_xref(doc, xref)
    smask_xref = image_handle["smask"] if image_handle else None

    if not smask_xref:
        return bool(fitz.Pixmap(doc, xref).alpha)

    # Both checks are evaluated unconditionally (no short-circuiting).
    # NOTE(review): ``is not None`` may always hold if ``extract_image``
    # returns an empty dict on failure -- confirm against PyMuPDF.
    smask_extractable = doc.extract_image(smask_xref) is not None
    smask_has_alpha = bool(fitz.Pixmap(doc, smask_xref).alpha)
    return smask_extractable or smask_has_alpha
|
||||
def get_image_infos(page: fitz.Page) -> List[dict]:
    """Return ``page.get_image_info(xrefs=True)`` for the given page."""
    image_infos = page.get_image_info(xrefs=True)
    return image_infos
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@ -124,11 +95,6 @@ def xref_to_image(doc, xref) -> Image:
|
||||
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]:
    """Return ``page.get_image_info(xrefs=True)``, memoised per page object.

    The cache is unbounded; ``clear_caches`` resets it.
    """
    image_infos = page.get_image_info(xrefs=True)
    return image_infos
|
||||
|
||||
|
||||
def get_image_metadata(image_info):
|
||||
|
||||
x1, y1, x2, y2 = map(rounder, image_info["bbox"])
|
||||
@ -146,8 +112,38 @@ def get_image_metadata(image_info):
|
||||
}
|
||||
|
||||
|
||||
def tiny(metadata):
    """Return ``True`` when width times height of ``metadata`` is at most 4."""
    area = metadata[Info.WIDTH] * metadata[Info.HEIGHT]
    return area <= 4
|
||||
def validate_coords_and_passthrough(metadata):
    """Lazily yield ``validate_box_coords(item)`` for each item of ``metadata``."""
    for item in metadata:
        yield validate_box_coords(item)
|
||||
|
||||
|
||||
def filter_out_tiny_images(metadata):
    """Return a lazy iterator over the ``metadata`` items for which ``tiny`` is falsy."""
    return (item for item in metadata if not tiny(item))
|
||||
|
||||
|
||||
def validate_size_and_passthrough(metadata):
    """Lazily yield ``validate_box_size(item)`` for each item of ``metadata``."""
    for item in metadata:
        yield validate_box_size(item)
|
||||
|
||||
|
||||
def add_page_metadata(page, metadata):
    """Lazily merge the page-level metadata of ``page`` with each item of ``metadata``.

    The page metadata is computed once, up front; the merged items are
    produced only as the returned iterator is consumed.
    """
    page_metadata = get_page_metadata(page)
    return (merge(page_metadata, item) for item in metadata)
|
||||
|
||||
|
||||
def add_alpha_channel_info(doc, page, metadata):
    """Lazily merge an ``{Info.ALPHA: bool}`` entry with each item of ``metadata``.

    Alpha flags are computed with ``has_alpha_channel`` for the ``"xref"`` of
    every image info on ``page`` and paired positionally with ``metadata``.

    NOTE(review): the xrefs come from the full ``get_image_infos(page)`` list,
    so pairing is only correct if ``metadata`` has the same length and order;
    verify this still holds when the caller filters ``metadata`` beforehand.
    """
    alpha_entries = (
        {Info.ALPHA: has_alpha_channel(doc, image_info["xref"])}
        for image_info in get_image_infos(page)
    )
    # The alpha iterator stays the first ``zip`` argument so it is advanced
    # before ``metadata`` on every step.
    return (merge(alpha_entry, item) for alpha_entry, item in zip(alpha_entries, metadata))
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
    """Return ``doc.extract_image(xref)``, memoised per ``(doc, xref)`` pair.

    The cache is unbounded; ``clear_caches`` resets it.
    """
    image_handle = doc.extract_image(xref)
    return image_handle
|
||||
|
||||
|
||||
# Left-to-right composition: apply ``round`` first, then convert the result with ``int``.
rounder = rcompose(round, int)
|
||||
|
||||
|
||||
def get_page_metadata(page):
|
||||
@ -160,6 +156,26 @@ def get_page_metadata(page):
|
||||
}
|
||||
|
||||
|
||||
# Left-to-right composition: apply ``round`` first, then convert the result with ``int``.
rounder = rcompose(round, int)
|
||||
def has_alpha_channel(doc, xref):
    """Return ``True`` when the image at ``xref`` appears to carry alpha information.

    If the extracted image handle has a truthy ``"smask"`` entry, the result is
    ``True`` when either the smask can be extracted from ``doc`` or the smask's
    pixmap reports alpha. Otherwise the alpha flag of the image's own pixmap
    decides.
    """
    image_handle = load_image_handle_from_xref(doc, xref)
    smask_xref = image_handle["smask"] if image_handle else None

    if not smask_xref:
        return bool(fitz.Pixmap(doc, xref).alpha)

    # Both checks are evaluated unconditionally (no short-circuiting).
    # NOTE(review): ``is not None`` may always hold if ``extract_image``
    # returns an empty dict on failure -- confirm against PyMuPDF.
    smask_extractable = doc.extract_image(smask_xref) is not None
    smask_has_alpha = bool(fitz.Pixmap(doc, smask_xref).alpha)
    return smask_extractable or smask_has_alpha
|
||||
|
||||
|
||||
def tiny(metadata):
    """Return ``True`` when width times height of ``metadata`` is at most 4."""
    area = metadata[Info.WIDTH] * metadata[Info.HEIGHT]
    return area <= 4
|
||||
|
||||
|
||||
def clear_caches():
    """Reset the caches of every cached helper in this module."""
    cached_helpers = (
        get_image_infos,
        load_image_handle_from_xref,
        get_images_on_page,
        xref_to_image,
    )
    for helper in cached_helpers:
        helper.cache_clear()
|
||||
|
||||
|
||||
# Clear the module's caches when the interpreter shuts down.
atexit.register(clear_caches)
|
||||
|
||||
@ -1,14 +1,7 @@
|
||||
from itertools import takewhile, starmap, islice, repeat
|
||||
from operator import truth
|
||||
|
||||
from funcy import iterate
|
||||
|
||||
|
||||
def chunk_iterable(iterable, chunk_size):
    """Split ``iterable`` lazily into tuples of at most ``chunk_size`` items.

    Chunks are produced in order; the last one may be shorter. Iteration ends
    at the first empty chunk, so an empty input yields nothing.
    """
    source = iter(iterable)

    def next_chunk():
        return tuple(islice(source, chunk_size))

    # Two-argument ``iter`` calls ``next_chunk`` until it returns the sentinel ``()``.
    return iter(next_chunk, ())
|
||||
from funcy import iterate, chunks
|
||||
|
||||
|
||||
def until(cond, func, *args, **kwargs):
    """Return the first element of the first pair for which ``cond`` holds.

    The sequence ``iterate(func, *args, **kwargs)`` is consumed in consecutive
    chunks of two, ``(a, b)``; as soon as ``cond(a, b)`` is truthy, ``a`` is
    returned.

    NOTE(review): chunks do not overlap, so ``cond`` is never evaluated on the
    second element of one pair together with the first of the next -- confirm
    this is intended.
    """
    # Only the funcy ``chunks`` loop is kept: the stale ``chunk_iterable``
    # header referred to a helper that this module no longer defines.
    for a, b in chunks(2, iterate(func, *args, **kwargs)):
        if cond(a, b):
            return a
|
||||
|
||||
@ -1,6 +1,32 @@
|
||||
from funcy import rcompose
|
||||
import pytest
|
||||
from funcy import rcompose, chunks
|
||||
|
||||
|
||||
def test_rcompose():
    """``rcompose`` applies its functions left to right: square, stringify, repeat."""
    square_stringify_repeat = rcompose(lambda x: x ** 2, str, lambda x: x * 2)
    result = square_stringify_repeat(3)
    assert result == "99"
|
||||
|
||||
|
||||
def test_chunk_iterable_exact_split():
    """Ten items in chunks of five give exactly two full chunks."""
    first, second = chunks(5, iter(range(10)))
    assert first == [0, 1, 2, 3, 4]
    assert second == [5, 6, 7, 8, 9]
|
||||
|
||||
|
||||
def test_chunk_iterable_no_split():
    """A chunk size equal to the input length gives the whole input as the first chunk."""
    all_chunks = chunks(10, iter(range(10)))
    first = next(all_chunks)
    assert first == list(range(10))
|
||||
|
||||
|
||||
def test_chunk_iterable_last_partial():
    """Ten items in chunks of three give four chunks, the last holding the remainder."""
    _first, _second, _third, last = chunks(3, iter(range(10)))
    assert last == [9]
|
||||
|
||||
|
||||
def test_chunk_iterable_empty():
    """Chunking an empty iterator yields no chunk at all."""
    with pytest.raises(StopIteration):
        no_chunks = chunks(3, iter(range(0)))
        next(no_chunks)
|
||||
|
||||
|
||||
def test_chunk_iterable_less_than_chunk_size_elements():
    """Fewer items than the chunk size give one short chunk."""
    only_chunk = next(chunks(5, iter(range(2))))
    assert only_chunk == [0, 1]
|
||||
|
||||
@ -1,34 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from image_prediction.utils.generic import chunk_iterable
|
||||
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
def test_predict(image_classifier, images, batch_of_expected_string_labels):
    """The classifier's predictions match the expected labels for each estimator type."""
    assert list(image_classifier.predict(images)) == batch_of_expected_string_labels
|
||||
|
||||
|
||||
def test_chunk_iterable_exact_split():
    """Ten items in chunks of five give exactly two full tuples."""
    first, second = chunk_iterable(range(10), chunk_size=5)
    assert first == (0, 1, 2, 3, 4)
    assert second == (5, 6, 7, 8, 9)
|
||||
|
||||
|
||||
def test_chunk_iterable_no_split():
    """A chunk size equal to the input length gives the whole input as the first chunk."""
    all_chunks = chunk_iterable(range(10), chunk_size=10)
    first = next(all_chunks)
    assert first == tuple(range(10))
|
||||
|
||||
|
||||
def test_chunk_iterable_last_partial():
    """Ten items in chunks of three give four chunks, the last holding the remainder."""
    _first, _second, _third, last = chunk_iterable(range(10), chunk_size=3)
    assert last == (9,)
|
||||
|
||||
|
||||
def test_chunk_iterable_empty():
    """Chunking an empty range yields no chunk at all."""
    with pytest.raises(StopIteration):
        no_chunks = chunk_iterable(range(0), chunk_size=3)
        next(no_chunks)
|
||||
|
||||
|
||||
def test_chunk_iterable_less_than_chunk_size_elements():
    """Fewer items than the chunk size give one short tuple."""
    only_chunk = next(chunk_iterable(range(2), chunk_size=5))
    assert only_chunk == (0, 1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user