Replaced string keys for metadata fields with enum members.

This commit is contained in:
Matthias Bisping 2022-03-29 20:29:44 +02:00
parent 358d7ecd91
commit 7340fb6dda
6 changed files with 41 additions and 22 deletions

View File

@ -7,6 +7,7 @@ from PIL import Image
from funcy import rcompose from funcy import rcompose
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info
class ParsablePDFImageExtractor(ImageExtractor): class ParsablePDFImageExtractor(ImageExtractor):
@ -21,15 +22,15 @@ class ParsablePDFImageExtractor(ImageExtractor):
x1, y1, x2, y2 = map(rounder, image_info["bbox"]) x1, y1, x2, y2 = map(rounder, image_info["bbox"])
width, height = itemgetter("width", "height")(image_info) width, height = itemgetter("width", "height")(image_info)
return { return {
"page_width": page_width, Info.PAGE_WIDTH: page_width,
"page_height": page_height, Info.PAGE_HEIGHT: page_height,
"page_idx": page.number, Info.PAGE_IDX: page.number,
"width": width, Info.WIDTH: width,
"height": height, Info.HEIGHT: height,
"x1": x1, Info.X1: x1,
"x2": x2, Info.X2: x2,
"y1": y1, Info.Y1: y1,
"y2": y2, Info.Y2: y2,
} }
rounder = rcompose(round, int) rounder = rcompose(round, int)

13
image_prediction/info.py Normal file
View File

@ -0,0 +1,13 @@
from enum import Enum, unique
@unique
class Info(Enum):
    """Keys of the metadata dict that accompanies an extracted image.

    Members are used directly as dictionary keys in place of the plain
    strings that were used before; each member's value is that former
    string key.

    ``@unique`` makes a duplicated value fail at import time. Without it the
    second member would silently become an alias of the first, and two
    metadata fields would collapse into a single dict entry.
    """

    # Dimensions of the page the image sits on.
    PAGE_WIDTH = "page_width"
    PAGE_HEIGHT = "page_height"
    # Index of the page holding the image (the parsable-PDF extractor stores
    # ``page.number`` here).
    PAGE_IDX = "page_idx"

    # Dimensions of the image itself.
    WIDTH = "width"
    HEIGHT = "height"

    # Bounding-box corner coordinates of the image (the parsable-PDF
    # extractor fills these from the rounded ``bbox`` values).
    X1 = "x1"
    X2 = "x2"
    Y1 = "y1"
    Y2 = "y2"

View File

@ -5,4 +5,4 @@ class DatabaseConnector(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def get_object(self, identifier): def get_object(self, identifier):
pass raise NotImplementedError

View File

@ -2,3 +2,4 @@
norecursedirs = incl norecursedirs = incl
filterwarnings = filterwarnings =
ignore:.*imp.*:DeprecationWarning ignore:.*imp.*:DeprecationWarning
ignore:.*Use.*:DeprecationWarning

View File

@ -18,6 +18,7 @@ from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExt
from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.info import Info
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector from image_prediction.model_loader.loaders.mlflow import MlflowConnector
@ -30,6 +31,8 @@ def image_extractor(extractor_type):
return ImageExtractorMock() return ImageExtractorMock()
elif extractor_type == "parsable_pdf": elif extractor_type == "parsable_pdf":
return ParsablePDFImageExtractor() return ParsablePDFImageExtractor()
elif extractor_type == "default":
return None
else: else:
raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.") raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
@ -149,15 +152,15 @@ def metadata(images):
y1 = random.randint(0, page_height - height) y1 = random.randint(0, page_height - height)
y2 = y1 + height y2 = y1 + height
metadata = { metadata = {
"page_width": page_width, Info.PAGE_WIDTH: page_width,
"page_height": page_height, Info.PAGE_HEIGHT: page_height,
"page_idx": current_page_idx(), Info.PAGE_IDX: current_page_idx(),
"width": width, Info.WIDTH: width,
"height": height, Info.HEIGHT: height,
"x1": x1, Info.X1: x1,
"x2": x2, Info.X2: x2,
"y1": y1, Info.Y1: y1,
"y2": y2 Info.Y2: y2,
} }
return metadata return metadata
@ -180,7 +183,7 @@ def pdf(image_metadata_pairs):
def add_image(pdf, image_metadata_pair): def add_image(pdf, image_metadata_pair):
while fewer_pages_then_required(image_metadata_pair.metadata["page_idx"], pdf): while fewer_pages_then_required(image_metadata_pair.metadata[Info.PAGE_IDX], pdf):
pdf.add_page() pdf.add_page()
add_image_to_last_page(pdf, image_metadata_pair) add_image_to_last_page(pdf, image_metadata_pair)
@ -196,7 +199,7 @@ def pdf_stream(pdf: fpdf.fpdf.FPDF):
def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair): def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
image, metadata = image_metadata_pair image, metadata = image_metadata_pair
x, y, w, h = itemgetter("x1", "y1", "width", "height")(metadata) x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata)
with tempfile.NamedTemporaryFile(suffix=".png") as temp_image: with tempfile.NamedTemporaryFile(suffix=".png") as temp_image:
image.save(temp_image.name) image.save(temp_image.name)

View File

@ -3,6 +3,7 @@ import pytest
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
from image_prediction.extraction import extract_images_from_pdf from image_prediction.extraction import extract_images_from_pdf
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
@pytest.mark.parametrize("extractor_type", ["mock"]) @pytest.mark.parametrize("extractor_type", ["mock"])
@ -12,7 +13,7 @@ def test_image_extractor_mock(image_extractor, images):
assert images_extracted == images assert images_extracted == images
@pytest.mark.parametrize("extractor_type", ["parsable_pdf"]) @pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 64]) @pytest.mark.parametrize("batch_size", [0, 1, 2, 64])
@pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"]) @pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"])
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size): def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size):