added response formatter and pipeline test
This commit is contained in:
parent
82added50a
commit
5caa9807e2
@ -11,10 +11,10 @@ def get_mlflow_model_loader(mlruns_dir):
|
|||||||
return model_loader
|
return model_loader
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline():
|
def load_pipeline(**kwargs):
|
||||||
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
|
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
|
||||||
model_identifier = CONFIG.service.run_id
|
model_identifier = CONFIG.service.run_id
|
||||||
|
|
||||||
pipeline = Pipeline(model_loader, model_identifier)
|
pipeline = Pipeline(model_loader, model_identifier, **kwargs)
|
||||||
|
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|||||||
@ -18,7 +18,7 @@ class ExtractorClassifier:
|
|||||||
images, metadata = zip(*batch)
|
images, metadata = zip(*batch)
|
||||||
|
|
||||||
predictions = self.classifier(images)
|
predictions = self.classifier(images)
|
||||||
responses = ({"prediction": prd, **mdt} for prd, mdt in zip(predictions, metadata))
|
responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
def __call__(self, obj) -> Iterable[ImageMetadataPair]:
|
def __call__(self, obj) -> Iterable[ImageMetadataPair]:
|
||||||
|
|||||||
72
image_prediction/formatter/formatters/response.py
Normal file
72
image_prediction/formatter/formatters/response.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import math
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
|
from image_prediction.config import CONFIG
|
||||||
|
from image_prediction.transformer.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseTransformer(Transformer):
    """Transformer that converts prediction records into image-info response dicts.

    The actual conversion of a single record is delegated to ``build_image_info``.
    """

    def transform(self, data):
        """Build the image-info payload for one record or for every record of an iterable.

        Args:
            data: Either a single mapping accepted by ``build_image_info`` or an
                iterable of such mappings.

        Returns:
            The dict produced by ``build_image_info`` when ``data`` is a mapping;
            otherwise a lazy ``map`` object yielding one such dict per element.
        """
        # Function-scope import so this block does not depend on the module's import section.
        from collections.abc import Mapping

        # Dispatch on the input's shape instead of catching TypeError: a TypeError
        # raised while processing a genuine record (e.g. a non-numeric dimension)
        # must surface as-is rather than being misread as "data is a collection",
        # which would end up iterating over the record's keys.
        if isinstance(data, Mapping):
            return build_image_info(data)

        return map(build_image_info, data)
|
||||||
|
|
||||||
|
|
||||||
|
def build_image_info(data: dict) -> dict:
    """Assemble the response entry for a single classified image.

    Reads the keys ``page_width``, ``page_height``, ``x1``, ``x2``, ``y1``, ``y2``,
    ``width``, ``height``, ``classification`` (which must contain a
    ``probabilities`` mapping) and ``page_idx`` from ``data`` and checks the
    derived quotients and the top probability against the thresholds in
    ``CONFIG.filters``.

    Args:
        data: Record holding the keys listed above.

    Returns:
        A dict with the sections ``classification``, ``position``, ``geometry``
        and ``filters``; ``filters["allPassed"]`` is True only if no threshold
        was breached.

    Raises:
        KeyError: If one of the required keys is missing.
        ZeroDivisionError: If ``height`` is zero or the page area is zero.
    """
    read_geometry = itemgetter("page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height")
    page_width, page_height, x1, x2, y1, y2, width, height = read_geometry(data)

    # Image-to-page quotient: ratio of the square roots of the two areas.
    # The rounded value is both reported and compared against the thresholds.
    page_side = math.sqrt(abs(page_width * page_height))
    image_side = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
    quotient = round(image_side / page_side, 4)

    size_limits = CONFIG.filters.image_to_page_quotient
    too_small = bool(quotient < size_limits.min)
    too_large = bool(quotient > size_limits.max)

    # The width-to-height quotient is compared unrounded and only rounded for the report.
    aspect = width / height
    format_limits = CONFIG.filters.image_width_to_height_quotient
    too_tall = bool(aspect < format_limits.min)
    too_wide = bool(aspect > format_limits.max)

    classification = data["classification"]
    top_probability = max(classification["probabilities"].values())
    unconfident = bool(top_probability < CONFIG.filters.min_confidence)

    # page_idx is zero-based in the record; the response reports it incremented by one.
    page_number = data["page_idx"] + 1

    any_breached = too_large or too_small or too_tall or too_wide or unconfident

    position = {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": page_number}
    geometry_filters = {
        "imageSize": {
            "quotient": quotient,
            "tooLarge": too_large,
            "tooSmall": too_small,
        },
        "imageFormat": {
            "quotient": round(aspect, 4),
            "tooTall": too_tall,
            "tooWide": too_wide,
        },
    }

    return {
        "classification": classification,
        "position": position,
        "geometry": {"width": width, "height": height},
        "filters": {
            "geometry": geometry_filters,
            "probability": {"unconfident": unconfident},
            "allPassed": not any_breached,
        },
    }
|
||||||
@ -7,3 +7,5 @@ CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
|
|||||||
|
|
||||||
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
|
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
|
||||||
MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
|
MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
|
||||||
|
|
||||||
|
TEST_DATA_DIR = path.join(PACKAGE_ROOT_DIR, "test", "data")
|
||||||
|
|||||||
@ -4,9 +4,12 @@ from funcy import rcompose, juxt
|
|||||||
|
|
||||||
from image_prediction.classifier.classifier import Classifier
|
from image_prediction.classifier.classifier import Classifier
|
||||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
|
from image_prediction.compositor.compositor import TransformerCompositor
|
||||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
||||||
|
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
||||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||||
|
from image_prediction.formatter.formatters.response import ResponseTransformer
|
||||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||||
|
|
||||||
@ -18,27 +21,29 @@ def get_image_classifier(model_loader, model_identifier):
|
|||||||
return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes)))
|
return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes)))
|
||||||
|
|
||||||
|
|
||||||
def get_extractor():
|
def get_extractor(**kwargs):
|
||||||
image_extractor = ParsablePDFImageExtractor(verbose=True)
|
image_extractor = ParsablePDFImageExtractor(**kwargs)
|
||||||
|
|
||||||
return image_extractor
|
return image_extractor
|
||||||
|
|
||||||
|
|
||||||
def get_extractor_classifier(model_loader, model_identifier):
|
def get_extractor_classifier(model_loader, model_identifier, **kwargs):
|
||||||
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, model_identifier))
|
extractor_classifier = ExtractorClassifier(
|
||||||
|
get_extractor(**kwargs), get_image_classifier(model_loader, model_identifier)
|
||||||
|
)
|
||||||
|
|
||||||
return extractor_classifier
|
return extractor_classifier
|
||||||
|
|
||||||
|
|
||||||
def get_formatter():
|
def get_formatter():
|
||||||
formatter = EnumFormatter()
|
formatter = TransformerCompositor(EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter())
|
||||||
|
|
||||||
return formatter
|
return formatter
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
def __init__(self, model_loader, model_identifier):
|
def __init__(self, model_loader, model_identifier, **kwargs):
|
||||||
self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier), get_formatter())
|
self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier, **kwargs), get_formatter())
|
||||||
|
|
||||||
def __call__(self, pdf: bytes):
|
def __call__(self, pdf: bytes):
|
||||||
yield from self.pipe(pdf)
|
yield from self.pipe(pdf)
|
||||||
|
|||||||
@ -19,8 +19,11 @@ def main(args):
|
|||||||
with open(args.pdf, "rb") as f:
|
with open(args.pdf, "rb") as f:
|
||||||
predictions = pipeline(f.read())
|
predictions = pipeline(f.read())
|
||||||
|
|
||||||
|
with open("/tmp/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json", "w") as f:
|
||||||
|
json.dump(list(predictions), f, indent=2)
|
||||||
|
|
||||||
for prd in predictions:
|
for prd in predictions:
|
||||||
print(json.dumps(prd, indent=1))
|
print(json.dumps(prd, indent=2))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
42
test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json
Normal file
42
test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"classification": {
|
||||||
|
"label": "formula",
|
||||||
|
"probabilities": {
|
||||||
|
"formula": 1.0,
|
||||||
|
"logo": 0.0,
|
||||||
|
"other": 0.0,
|
||||||
|
"signature": 0.0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"position": {
|
||||||
|
"x1": 321,
|
||||||
|
"x2": 515,
|
||||||
|
"y1": 300,
|
||||||
|
"y2": 494,
|
||||||
|
"pageNumber": 2
|
||||||
|
},
|
||||||
|
"geometry": {
|
||||||
|
"width": 389,
|
||||||
|
"height": 389
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"geometry": {
|
||||||
|
"imageSize": {
|
||||||
|
"quotient": 0.2741,
|
||||||
|
"tooLarge": false,
|
||||||
|
"tooSmall": false
|
||||||
|
},
|
||||||
|
"imageFormat": {
|
||||||
|
"quotient": 1.0,
|
||||||
|
"tooTall": false,
|
||||||
|
"tooWide": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"probability": {
|
||||||
|
"unconfident": false
|
||||||
|
},
|
||||||
|
"allPassed": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
@ -227,13 +227,13 @@ def map_labels(numeric_labels, classes):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
|
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
|
||||||
return [{"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)]
|
return [{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
|
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
|
||||||
return [
|
return [
|
||||||
{"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted)
|
{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -10,5 +10,5 @@ from image_prediction.extractor_classifier.extractor_classifier import Extractor
|
|||||||
def test_extractor_classifier(image_extractor, image_classifier, images, batch_of_expected_string_labels):
|
def test_extractor_classifier(image_extractor, image_classifier, images, batch_of_expected_string_labels):
|
||||||
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
|
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
|
||||||
results = extractor_classifier(images)
|
results = extractor_classifier(images)
|
||||||
labels = list(map(itemgetter("prediction"), results))
|
labels = list(map(itemgetter("classification"), results))
|
||||||
assert labels == batch_of_expected_string_labels
|
assert labels == batch_of_expected_string_labels
|
||||||
|
|||||||
18
test/unit_tests/pipeline_test.py
Normal file
18
test/unit_tests/pipeline_test.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from image_prediction.default_objects import load_pipeline
|
||||||
|
from image_prediction.locations import TEST_DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline():
    """Run the default pipeline on the sample PDF and compare against the stored predictions."""
    pdf_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
    expectations_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json")

    pipeline = load_pipeline(verbose=False)

    with open(pdf_path, "rb") as pdf_file:
        pdf_bytes = pdf_file.read()

    # The pipeline yields its results lazily; materialize them for the comparison.
    predictions = list(pipeline(pdf_bytes))

    with open(expectations_path, "r") as expectations_file:
        expectations = json.load(expectations_file)

    assert predictions == expectations
|
||||||
Loading…
x
Reference in New Issue
Block a user