Compare commits
5 Commits
master
...
RED-5107-r
| Author | SHA1 | Date | |
|---|---|---|---|
| | 265c61df1a | | |
| | 7c6f9809bc | | |
| | 6c54cea57d | | |
| | 37a7e0a0e7 | | |
| | c03913e088 | | |
@ -1,6 +1,6 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Mapping, Iterable
|
from typing import Mapping, Iterable, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from funcy import rcompose, rpartial
|
from funcy import rcompose, rpartial
|
||||||
@ -27,7 +27,10 @@ class ProbabilityMapper(LabelMapper):
|
|||||||
f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})."
|
f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __map_array(self, probabilities: np.ndarray) -> dict:
|
def __map_array(self, probabilities: Union[np.ndarray, None]) -> Union[dict, None]:
|
||||||
|
if not isinstance(probabilities, np.ndarray) and not probabilities:
|
||||||
|
return None
|
||||||
|
|
||||||
self.__validate_array_label_format(probabilities)
|
self.__validate_array_label_format(probabilities)
|
||||||
cls2prob = dict(
|
cls2prob = dict(
|
||||||
sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)
|
sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from funcy import rcompose
|
|
||||||
|
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
@ -9,10 +7,13 @@ class PredictionModelHandle:
|
|||||||
"""Simplifies usage of ModelHandle instances for prediction purposes."""
|
"""Simplifies usage of ModelHandle instances for prediction purposes."""
|
||||||
|
|
||||||
def __init__(self, model_handle):
|
def __init__(self, model_handle):
|
||||||
self.__predict = rcompose(model_handle.prep_images, model_handle.model.predict)
|
self.__prep_images = model_handle.prep_images
|
||||||
|
self.__predict = model_handle.model.predict
|
||||||
|
|
||||||
def predict(self, *args, **kwargs):
|
def predict(self, *args, **kwargs):
|
||||||
return self.__predict(*args, **kwargs)
|
tensor, valid_mask = self.__prep_images(*args, **kwargs)
|
||||||
|
predictions = self.__predict(tensor, **kwargs)
|
||||||
|
return [p if v else None for p, v in zip(predictions, valid_mask)]
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
logger.debug("PredictionModelHandle.predict")
|
logger.debug("PredictionModelHandle.predict")
|
||||||
|
|||||||
@ -2,6 +2,11 @@ import abc
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class ModelWrapper(abc.ABC):
|
class ModelWrapper(abc.ABC):
|
||||||
@ -30,12 +35,29 @@ class ModelWrapper(abc.ABC):
|
|||||||
def __resize_and_convert(self, image):
|
def __resize_and_convert(self, image):
|
||||||
return image.resize(self.input_shape[:-1]).convert("RGB")
|
return image.resize(self.input_shape[:-1]).convert("RGB")
|
||||||
|
|
||||||
|
def __monitored_resize_and_convert(self, image):
|
||||||
|
# RED-5170: fails if image is 'broken'
|
||||||
|
try:
|
||||||
|
image, valid = self.__resize_and_convert(image), True
|
||||||
|
except OSError as err:
|
||||||
|
image, valid = self.__handle_resize_exception(err)
|
||||||
|
except Exception as err:
|
||||||
|
image, valid = self.__handle_resize_exception(err)
|
||||||
|
|
||||||
|
return image, valid
|
||||||
|
|
||||||
|
def __handle_resize_exception(self, err):
|
||||||
|
logger.warn(f"{err}: couldn't resize image, replace and passthrough.")
|
||||||
|
image = Image.new("RGB", self.input_shape[:-1])
|
||||||
|
valid = False
|
||||||
|
return image, valid
|
||||||
|
|
||||||
def prep_images(self, images):
|
def prep_images(self, images):
|
||||||
images = map(self.__resize_and_convert, images)
|
images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images))
|
||||||
tensor = self.__images_to_tensor(images)
|
tensor = self.__images_to_tensor(images)
|
||||||
tensor = self.__preprocess_tensor(tensor)
|
tensor = self.__preprocess_tensor(tensor)
|
||||||
|
|
||||||
return tensor
|
return tensor, valid_mask
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def __build(self, base_weights=None) -> tf.keras.models.Model:
|
def __build(self, base_weights=None) -> tf.keras.models.Model:
|
||||||
|
|||||||
@ -3,15 +3,18 @@ from functools import reduce
|
|||||||
from typing import Iterable, Callable, List
|
from typing import Iterable, Callable, List
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import juxt, first, rest, rcompose, rpartial, complement, ilen
|
from funcy import juxt, first, rest, rcompose, rpartial
|
||||||
|
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
from image_prediction.stitching.grouping import CoordGrouper
|
from image_prediction.stitching.grouping import CoordGrouper
|
||||||
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
|
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
|
||||||
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once, validate_box
|
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once, validate_box
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
from image_prediction.utils.generic import until
|
from image_prediction.utils.generic import until
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def make_merger_sentinel():
|
def make_merger_sentinel():
|
||||||
def no_new_mergers(pairs):
|
def no_new_mergers(pairs):
|
||||||
@ -184,6 +187,10 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
|||||||
|
|
||||||
for im, offset in zip(images, offsets):
|
for im, offset in zip(images, offsets):
|
||||||
box = (offset, 0) if not axis else (0, offset)
|
box = (offset, 0) if not axis else (0, offset)
|
||||||
im_aggr.paste(im, box=box)
|
|
||||||
|
try: # RED-5170: fails if image is 'broken'
|
||||||
|
im_aggr.paste(im, box=box)
|
||||||
|
except Exception as err:
|
||||||
|
logger.warn(f"{err}: couldn't merge image, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})")
|
||||||
|
|
||||||
return im_aggr
|
return im_aggr
|
||||||
|
|||||||
@ -35,10 +35,15 @@ def build_image_info(data: dict) -> dict:
|
|||||||
width / height > CONFIG.filters.image_width_to_height_quotient.max
|
width / height > CONFIG.filters.image_width_to_height_quotient.max
|
||||||
)
|
)
|
||||||
|
|
||||||
classification = data["classification"]
|
# FIXME: pass in fallback value for classification and introduce new key for image validness
|
||||||
|
classification = data["classification"] or "other"
|
||||||
representation = data["representation"]
|
representation = data["representation"]
|
||||||
|
|
||||||
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
|
min_confidence_breached = (
|
||||||
|
bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
|
||||||
|
if data["classification"]
|
||||||
|
else True
|
||||||
|
)
|
||||||
|
|
||||||
image_info = {
|
image_info = {
|
||||||
"classification": classification,
|
"classification": classification,
|
||||||
|
|||||||
2
test/fixtures/model.py
vendored
2
test/fixtures/model.py
vendored
@ -102,7 +102,7 @@ def model_handle_mock(estimator_mock):
|
|||||||
self.model = estimator_mock
|
self.model = estimator_mock
|
||||||
|
|
||||||
def prep_images(self, batch):
|
def prep_images(self, batch):
|
||||||
return [None for _ in batch]
|
return [True for _ in batch], [None for _ in batch]
|
||||||
|
|
||||||
def predict(self, batch):
|
def predict(self, batch):
|
||||||
return [None for _ in batch]
|
return [None for _ in batch]
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import fitz
|
|||||||
import fpdf
|
import fpdf
|
||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import first, rest
|
from funcy import first, rest, lmap
|
||||||
|
|
||||||
from image_prediction.extraction import extract_images_from_pdf
|
from image_prediction.extraction import extract_images_from_pdf
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
@ -27,6 +27,7 @@ def test_image_extractor_mock(image_extractor, images):
|
|||||||
@pytest.mark.parametrize("alpha", [False, True])
|
@pytest.mark.parametrize("alpha", [False, True])
|
||||||
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha):
|
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha):
|
||||||
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
|
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
|
||||||
|
|
||||||
if not alpha:
|
if not alpha:
|
||||||
assert image_sets_equal(images_extracted, images)
|
assert image_sets_equal(images_extracted, images)
|
||||||
assert metadata_equal(metadata_extracted, metadata)
|
assert metadata_equal(metadata_extracted, metadata)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user