refactoring: estimator + model

This commit is contained in:
Matthias Bisping 2022-03-25 11:23:07 +01:00
parent eb18ae8719
commit ee959346b7
9 changed files with 69 additions and 101 deletions

View File

View File

@ -0,0 +1,14 @@
class EstimatorMock:
    """Estimator stand-in that returns a preconfigured output batch.

    The value handed back by :meth:`predict` is whatever was last stored in
    ``output_batch``; the input passed to :meth:`predict` is never inspected.
    """

    def __init__(self, output_batch=None):
        """Create the mock.

        Args:
            output_batch: Optional canned result for :meth:`predict`.
                Defaults to ``None``, matching an unconfigured mock.
        """
        self.__output_batch = output_batch

    @property
    def output_batch(self):
        """The canned result that :meth:`predict` returns."""
        return self.__output_batch

    @output_batch.setter
    def output_batch(self, output_batch):
        self.__output_batch = output_batch

    def predict(self, batch):
        """Return the stored output batch, ignoring ``batch``."""
        return self.__output_batch

View File

View File

@ -0,0 +1,6 @@
from image_prediction.model.model import Model
class ModelMock(Model):
    """``Model`` subclass that adds no behaviour of its own."""

    def __init__(self, estimator):
        # Hand the estimator to the base-class constructor unchanged.
        super(ModelMock, self).__init__(estimator=estimator)

View File

@ -0,0 +1,13 @@
import abc
class Model(abc.ABC):
    """Base class that wraps an estimator and delegates prediction to it.

    Derives from ``abc.ABC`` but declares no abstract methods itself.
    """

    def __init__(self, estimator):
        """Store ``estimator``; it is exposed read-only via :attr:`estimator`."""
        self.__estimator = estimator

    @property
    def estimator(self):
        """The estimator passed to the constructor."""
        return self.__estimator

    def predict(self, batch):
        """Return ``estimator.predict(batch)``."""
        # Go through the property (not the private attribute) so that the
        # lookup behaves the same as for any external caller.
        backend = self.estimator
        return backend.predict(batch)

View File

@ -1,70 +0,0 @@
import os.path
import pytest
from image_prediction.predictor import Predictor
@pytest.fixture
def predictions():
    """Provide a one-element list holding a prediction dict.

    The dict carries a ``"class"`` label and a ``"probabilities"`` mapping
    from class name to probability.
    """
    class_probabilities = {
        "signature": 1.0,
        "logo": 9.150285377746546e-19,
        "other": 4.374506412383356e-19,
        "formula": 3.582569597002796e-24,
    }
    signature_prediction = {"class": "signature", "probabilities": class_probabilities}
    return [signature_prediction]
@pytest.fixture
def metadata():
    """Provide a one-element list holding a metadata dict.

    The dict contains page dimensions, image dimensions, the page index and
    the ``x1``/``x2``/``y1``/``y2`` coordinates.
    """
    # Keyword order mirrors the key order of the original literal.
    image_metadata = dict(
        page_height=612.0,
        page_width=792.0,
        height=61.049999999999955,
        width=139.35000000000002,
        page_idx=8,
        x1=63.5,
        x2=202.85000000000002,
        y1=472.0,
        y2=533.05,
    )
    return [image_metadata]
@pytest.fixture
def response():
    """Provide a one-element list holding a response dict.

    The dict has four sections: ``classification``, ``filters``,
    ``geometry`` and ``position``.
    """
    classification = {
        "label": "signature",
        "probabilities": {"formula": 0.0, "logo": 0.0, "other": 0.0, "signature": 1.0},
    }
    geometry_filters = {
        "imageFormat": {"quotient": 2.282555282555285, "tooTall": False, "tooWide": False},
        "imageSize": {"quotient": 0.13248234868245012, "tooLarge": False, "tooSmall": False},
    }
    filters = {
        "allPassed": True,
        "geometry": geometry_filters,
        "probability": {"unconfident": False},
    }
    geometry = {"height": 61.049999999999955, "width": 139.35000000000002}
    position = {"pageNumber": 9, "x1": 63.5, "x2": 202.85000000000002, "y1": 472.0, "y2": 533.05}

    entry = {
        "classification": classification,
        "filters": filters,
        "geometry": geometry,
        "position": position,
    }
    return [entry]
@pytest.fixture
def predictor():
    """Provide a ``Predictor`` constructed with no arguments."""
    instance = Predictor()
    return instance
@pytest.fixture
def test_pdf():
    """Provide the raw bytes of the sample PDF under ``./test/test_data``."""
    # The path is relative, so it is resolved against the current working
    # directory of the test run.
    pdf_path = "./test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"
    with open(pdf_path, mode="rb") as pdf_file:
        content = pdf_file.read()
    return content

View File

@ -0,0 +1,36 @@
import numpy as np
import pytest
from image_prediction.estimator.mock import EstimatorMock
from image_prediction.model.mock import ModelMock
@pytest.fixture(scope="session")
def estimator():
    """Provide a session-scoped ``EstimatorMock``."""
    mock = EstimatorMock()
    return mock
@pytest.fixture(scope="session")
def batches(batch_size):
    """Build an ``(input_batch, output_batch)`` pair of arrays.

    Both arrays have shape ``(batch_size, 10, 15)``.
    """
    shape = (batch_size, 10, 15)
    inputs = np.random.normal(size=shape)
    # ``high`` is exclusive, so every element of the output batch equals 42.
    outputs = np.random.randint(low=42, high=43, size=shape)
    return inputs, outputs
@pytest.fixture(scope="session")
def classes():
    """Provide the list of class labels ``A``, ``B`` and ``C``."""
    # Iterating the string yields one single-character label per letter.
    return list("ABC")
@pytest.fixture(scope="session")
def model(model_type, estimator):
    """Build the model selected by ``model_type`` around ``estimator``.

    Args:
        model_type: Name of the model flavour; only ``"mock"`` is supported.
        estimator: Estimator instance handed to the model constructor.

    Returns:
        A ``ModelMock`` wrapping ``estimator`` when ``model_type`` is ``"mock"``.

    Raises:
        ValueError: If ``model_type`` is not a supported name.
    """
    if model_type == "mock":
        return ModelMock(estimator)
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an unrelated AttributeError inside the test body.
    raise ValueError(f"Unknown model type: {model_type!r}")
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(model, batches):
    """Check that ``model.predict`` returns the estimator's output batch."""
    inputs, expected = batches
    # Configure the model's estimator with the output it should hand back.
    model.estimator.output_batch = expected
    predicted = model.predict(inputs)
    assert np.equal(predicted, expected).all()

View File

@ -1,26 +0,0 @@
def test_predict_pdf_works(predictor, test_pdf):
    """Check the first prediction and first metadata entry for the sample PDF."""
    # FIXME: ugly test, since the dict contains '\n's of unknown origin
    predictions, metadata = predictor.predict_pdf(test_pdf)

    first_prediction = list(predictions)[0]
    assert first_prediction["class"] == "formula"

    # Floating-point precision makes the output values unstable, so only the
    # presence of the keys is checked, not the values.
    probabilities = first_prediction["probabilities"]
    for label in ("formula", "other", "signature", "logo"):
        assert label in probabilities

    first_metadata = dict(**list(metadata)[0])
    first_metadata.pop("document_filename")  # temp filename cannot be tested

    expected_metadata = {
        "px_width": 389.0,
        "px_height": 389.0,
        "width": 194.49999000000003,
        "height": 194.49998999999997,
        "x1": 320.861,
        "x2": 515.36099,
        "y1": 347.699,
        "y2": 542.19899,
        "page_width": 595.2800000000001,
        "page_height": 841.89,
        "page_rotation": 0,
        "page_idx": 1,
        "n_pages": 3,
    }
    assert first_metadata == expected_metadata

View File

@ -1,5 +0,0 @@
from image_prediction.response import build_response
def test_build_response_returns_valid_response(predictions, metadata, response):
    """Check that ``build_response`` yields the expected response fixture."""
    built = build_response(predictions, metadata)
    assert built == response