refactoring: estimator + model
This commit is contained in:
parent
eb18ae8719
commit
ee959346b7
0
image_prediction/estimator/__init__.py
Normal file
0
image_prediction/estimator/__init__.py
Normal file
14
image_prediction/estimator/mock.py
Normal file
14
image_prediction/estimator/mock.py
Normal file
@ -0,0 +1,14 @@
|
||||
class EstimatorMock:
    """Test double for an estimator: predict() echoes a preset batch."""

    def __init__(self):
        # No canned output until a test assigns one via the property setter.
        self._canned_output = None

    @property
    def output_batch(self):
        """Batch that predict() will hand back, whatever its input."""
        return self._canned_output

    @output_batch.setter
    def output_batch(self, batch):
        self._canned_output = batch

    def predict(self, batch):
        """Ignore *batch* and return the preconfigured output batch."""
        return self._canned_output
|
||||
0
image_prediction/model/__init__.py
Normal file
0
image_prediction/model/__init__.py
Normal file
6
image_prediction/model/mock.py
Normal file
6
image_prediction/model/mock.py
Normal file
@ -0,0 +1,6 @@
|
||||
from image_prediction.model.model import Model
|
||||
|
||||
|
||||
class ModelMock(Model):
    """Concrete Model used in tests; inherits all behavior unchanged.

    The base class constructor already takes the estimator, so the
    previous ``__init__`` override — which only forwarded the identical
    argument to ``super().__init__`` — was redundant boilerplate.
    ``ModelMock(estimator)`` keeps working via the inherited constructor.
    """
|
||||
13
image_prediction/model/model.py
Normal file
13
image_prediction/model/model.py
Normal file
@ -0,0 +1,13 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Model(abc.ABC):
    """Abstract base that wraps an estimator and delegates prediction to it."""

    def __init__(self, estimator):
        # Kept private; read-only access goes through the property below.
        self.__wrapped = estimator

    @property
    def estimator(self):
        """Estimator instance handed in at construction time."""
        return self.__wrapped

    def predict(self, batch):
        """Delegate *batch* to the wrapped estimator's predict()."""
        return self.__wrapped.predict(batch)
|
||||
@ -1,70 +0,0 @@
|
||||
import os.path
|
||||
|
||||
import pytest
|
||||
|
||||
from image_prediction.predictor import Predictor
|
||||
|
||||
|
||||
@pytest.fixture
def predictions():
    """One classifier prediction, shaped as the predictor emits it."""
    probabilities = {
        "signature": 1.0,
        "logo": 9.150285377746546e-19,
        "other": 4.374506412383356e-19,
        "formula": 3.582569597002796e-24,
    }
    return [{"class": "signature", "probabilities": probabilities}]
|
||||
|
||||
|
||||
@pytest.fixture
def metadata():
    """Geometry/page metadata matching the prediction fixture."""
    box = {
        "page_height": 612.0,
        "page_width": 792.0,
        "height": 61.049999999999955,
        "width": 139.35000000000002,
        "page_idx": 8,
        "x1": 63.5,
        "x2": 202.85000000000002,
        "y1": 472.0,
        "y2": 533.05,
    }
    return [box]
|
||||
|
||||
|
||||
@pytest.fixture
def response():
    """Expected API payload for the corresponding prediction/metadata."""
    classification = {
        "label": "signature",
        "probabilities": {"formula": 0.0, "logo": 0.0, "other": 0.0, "signature": 1.0},
    }
    filters = {
        "allPassed": True,
        "geometry": {
            "imageFormat": {"quotient": 2.282555282555285, "tooTall": False, "tooWide": False},
            "imageSize": {"quotient": 0.13248234868245012, "tooLarge": False, "tooSmall": False},
        },
        "probability": {"unconfident": False},
    }
    item = {
        "classification": classification,
        "filters": filters,
        "geometry": {"height": 61.049999999999955, "width": 139.35000000000002},
        "position": {"pageNumber": 9, "x1": 63.5, "x2": 202.85000000000002, "y1": 472.0, "y2": 533.05},
    }
    return [item]
|
||||
|
||||
|
||||
@pytest.fixture
def predictor():
    """A fresh Predictor instance per test."""
    instance = Predictor()
    return instance
|
||||
|
||||
|
||||
@pytest.fixture
def test_pdf():
    """Raw bytes of the checked-in sample PDF."""
    pdf_path = "./test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"
    with open(pdf_path, "rb") as pdf_file:
        return pdf_file.read()
|
||||
36
test/unit_tests/model_test.py
Normal file
36
test/unit_tests/model_test.py
Normal file
@ -0,0 +1,36 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from image_prediction.estimator.mock import EstimatorMock
|
||||
from image_prediction.model.mock import ModelMock
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def estimator():
    """Single EstimatorMock shared across the whole test session."""
    mock = EstimatorMock()
    return mock
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def batches(batch_size):
    """Random input batch plus a constant expected output batch.

    randint(low=42, high=43) can only ever draw 42, so the output batch
    is deterministic even though the input batch is random.
    """
    shape = (batch_size, 10, 15)
    input_batch = np.random.normal(size=shape)
    output_batch = np.random.randint(low=42, high=43, size=shape)
    return input_batch, output_batch
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def classes():
    """Fixed label set used by the tests."""
    return list("ABC")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def model(model_type, estimator):
    """Build the model variant selected by the ``model_type`` parameter.

    The original fell through and silently returned ``None`` for an
    unrecognized ``model_type``, which would only surface later as a
    confusing AttributeError in the test body; fail fast instead.

    Raises:
        ValueError: if ``model_type`` is not a known variant.
    """
    if model_type == "mock":
        return ModelMock(estimator)
    raise ValueError(f"unknown model_type: {model_type!r}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(model, batches):
    """predict() must return the canned output batch unchanged."""
    input_batch, expected = batches
    # The mock estimator echoes whatever output batch it was primed with.
    model.estimator.output_batch = expected
    predicted = model.predict(input_batch)
    assert np.all(np.equal(predicted, expected))
|
||||
@ -1,26 +0,0 @@
|
||||
def test_predict_pdf_works(predictor, test_pdf):
    # FIXME ugly test since there are '\n's in the dict with unknown heritage
    predictions, metadata = predictor.predict_pdf(test_pdf)
    first_prediction = list(predictions)[0]
    assert first_prediction["class"] == "formula"
    # Floating point precision problem for output so test only that keys exist not the values
    probabilities = first_prediction["probabilities"]
    for key in ("formula", "other", "signature", "logo"):
        assert key in probabilities
    first_metadata = dict(**list(metadata)[0])
    first_metadata.pop("document_filename")  # temp filename cannot be tested
    expected = {
        "px_width": 389.0,
        "px_height": 389.0,
        "width": 194.49999000000003,
        "height": 194.49998999999997,
        "x1": 320.861,
        "x2": 515.36099,
        "y1": 347.699,
        "y2": 542.19899,
        "page_width": 595.2800000000001,
        "page_height": 841.89,
        "page_rotation": 0,
        "page_idx": 1,
        "n_pages": 3,
    }
    assert first_metadata == expected
|
||||
@ -1,5 +0,0 @@
|
||||
from image_prediction.response import build_response
|
||||
|
||||
|
||||
def test_build_response_returns_valid_response(predictions, metadata, response):
    """build_response must assemble exactly the expected response payload."""
    actual = build_response(predictions, metadata)
    assert actual == response
|
||||
Loading…
x
Reference in New Issue
Block a user