diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 3f3f21e..ad3655f 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -146,22 +146,24 @@ def xref_to_image(doc, xref) -> Union[Image.Image, None]: # NOTE: image extraction is done via pixmap to array, as this method is twice as fast as extraction via bytestream try: pixmap = fitz.Pixmap(doc, xref) - array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(pixmap.h, pixmap.w, pixmap.n) - array = normalize_channels(array) + array = convert_pixmap_to_array(pixmap) return Image.fromarray(array) except ValueError: logger.debug(f"Xref {xref} is invalid, skipping extraction ...") return -def normalize_channels(array: np.ndarray): - if not array.ndim == 3: - array = np.expand_dims(array, axis=-1) +def convert_pixmap_to_array(pixmap: fitz.fitz.Pixmap): + array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(pixmap.h, pixmap.w, pixmap.n) + array = _normalize_channels(array) + return array - if array.shape[-1] == 4: + +def _normalize_channels(array: np.ndarray): + if array.shape[-1] == 1: + array = array[:, :, 0] + elif array.shape[-1] == 4: array = array[..., :3] - elif array.shape[-1] == 1: - array = np.concatenate([array, array, array], axis=-1) elif array.shape[-1] != 3: logger.warning(f"Unexpected image format: {array.shape}.") raise ValueError(f"Unexpected image format: {array.shape}.")