Compare commits
5 Commits
master
...
RED-5107-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
265c61df1a | ||
|
|
7c6f9809bc | ||
|
|
6c54cea57d | ||
|
|
37a7e0a0e7 | ||
|
|
c03913e088 |
@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
from operator import itemgetter
|
||||
from typing import Mapping, Iterable
|
||||
from typing import Mapping, Iterable, Union
|
||||
|
||||
import numpy as np
|
||||
from funcy import rcompose, rpartial
|
||||
@ -27,7 +27,10 @@ class ProbabilityMapper(LabelMapper):
|
||||
f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})."
|
||||
)
|
||||
|
||||
def __map_array(self, probabilities: np.ndarray) -> dict:
|
||||
def __map_array(self, probabilities: Union[np.ndarray, None]) -> Union[dict, None]:
|
||||
if not isinstance(probabilities, np.ndarray) and not probabilities:
|
||||
return None
|
||||
|
||||
self.__validate_array_label_format(probabilities)
|
||||
cls2prob = dict(
|
||||
sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from funcy import rcompose
|
||||
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@ -9,10 +7,13 @@ class PredictionModelHandle:
|
||||
"""Simplifies usage of ModelHandle instances for prediction purposes."""
|
||||
|
||||
def __init__(self, model_handle):
|
||||
self.__predict = rcompose(model_handle.prep_images, model_handle.model.predict)
|
||||
self.__prep_images = model_handle.prep_images
|
||||
self.__predict = model_handle.model.predict
|
||||
|
||||
def predict(self, *args, **kwargs):
|
||||
return self.__predict(*args, **kwargs)
|
||||
tensor, valid_mask = self.__prep_images(*args, **kwargs)
|
||||
predictions = self.__predict(tensor, **kwargs)
|
||||
return [p if v else None for p, v in zip(predictions, valid_mask)]
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
logger.debug("PredictionModelHandle.predict")
|
||||
|
||||
@ -2,6 +2,11 @@ import abc
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class ModelWrapper(abc.ABC):
|
||||
@ -30,12 +35,29 @@ class ModelWrapper(abc.ABC):
|
||||
def __resize_and_convert(self, image):
|
||||
return image.resize(self.input_shape[:-1]).convert("RGB")
|
||||
|
||||
def __monitored_resize_and_convert(self, image):
|
||||
# RED-5170: fails if image is 'broken'
|
||||
try:
|
||||
image, valid = self.__resize_and_convert(image), True
|
||||
except OSError as err:
|
||||
image, valid = self.__handle_resize_exception(err)
|
||||
except Exception as err:
|
||||
image, valid = self.__handle_resize_exception(err)
|
||||
|
||||
return image, valid
|
||||
|
||||
def __handle_resize_exception(self, err):
|
||||
logger.warn(f"{err}: couldn't resize image, replace and passthrough.")
|
||||
image = Image.new("RGB", self.input_shape[:-1])
|
||||
valid = False
|
||||
return image, valid
|
||||
|
||||
def prep_images(self, images):
|
||||
images = map(self.__resize_and_convert, images)
|
||||
images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images))
|
||||
tensor = self.__images_to_tensor(images)
|
||||
tensor = self.__preprocess_tensor(tensor)
|
||||
|
||||
return tensor
|
||||
return tensor, valid_mask
|
||||
|
||||
@abc.abstractmethod
|
||||
def __build(self, base_weights=None) -> tf.keras.models.Model:
|
||||
|
||||
@ -3,15 +3,18 @@ from functools import reduce
|
||||
from typing import Iterable, Callable, List
|
||||
|
||||
from PIL import Image
|
||||
from funcy import juxt, first, rest, rcompose, rpartial, complement, ilen
|
||||
from funcy import juxt, first, rest, rcompose, rpartial
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.stitching.grouping import CoordGrouper
|
||||
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
|
||||
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once, validate_box
|
||||
from image_prediction.utils import get_logger
|
||||
from image_prediction.utils.generic import until
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def make_merger_sentinel():
|
||||
def no_new_mergers(pairs):
|
||||
@ -184,6 +187,10 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
||||
|
||||
for im, offset in zip(images, offsets):
|
||||
box = (offset, 0) if not axis else (0, offset)
|
||||
im_aggr.paste(im, box=box)
|
||||
|
||||
try: # RED-5170: fails if image is 'broken'
|
||||
im_aggr.paste(im, box=box)
|
||||
except Exception as err:
|
||||
logger.warn(f"{err}: couldn't merge image, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})")
|
||||
|
||||
return im_aggr
|
||||
|
||||
@ -35,10 +35,15 @@ def build_image_info(data: dict) -> dict:
|
||||
width / height > CONFIG.filters.image_width_to_height_quotient.max
|
||||
)
|
||||
|
||||
classification = data["classification"]
|
||||
# FIXME: pass in fallback value for classification and introduce new key for image validness
|
||||
classification = data["classification"] or "other"
|
||||
representation = data["representation"]
|
||||
|
||||
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
|
||||
min_confidence_breached = (
|
||||
bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
|
||||
if data["classification"]
|
||||
else True
|
||||
)
|
||||
|
||||
image_info = {
|
||||
"classification": classification,
|
||||
|
||||
2
test/fixtures/model.py
vendored
2
test/fixtures/model.py
vendored
@ -102,7 +102,7 @@ def model_handle_mock(estimator_mock):
|
||||
self.model = estimator_mock
|
||||
|
||||
def prep_images(self, batch):
|
||||
return [None for _ in batch]
|
||||
return [True for _ in batch], [None for _ in batch]
|
||||
|
||||
def predict(self, batch):
|
||||
return [None for _ in batch]
|
||||
|
||||
@ -5,7 +5,7 @@ import fitz
|
||||
import fpdf
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from funcy import first, rest
|
||||
from funcy import first, rest, lmap
|
||||
|
||||
from image_prediction.extraction import extract_images_from_pdf
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
@ -27,6 +27,7 @@ def test_image_extractor_mock(image_extractor, images):
|
||||
@pytest.mark.parametrize("alpha", [False, True])
|
||||
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha):
|
||||
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
|
||||
|
||||
if not alpha:
|
||||
assert image_sets_equal(images_extracted, images)
|
||||
assert metadata_equal(metadata_extracted, metadata)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user