applied black
This commit is contained in:
parent
c372529ee5
commit
da9b3d0cb9
@ -1,4 +1,5 @@
|
||||
import random
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
from functools import partial
|
||||
@ -15,8 +16,12 @@ from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType, \
|
||||
UnknownLabelFormat
|
||||
from image_prediction.exceptions import (
|
||||
UnknownEstimatorAdapter,
|
||||
UnknownImageExtractor,
|
||||
UnknownDatabaseType,
|
||||
UnknownLabelFormat,
|
||||
)
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
@ -87,7 +92,7 @@ def label_format(request):
|
||||
|
||||
@pytest.fixture
|
||||
def expected_predictions_mapped(
|
||||
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
|
||||
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
|
||||
):
|
||||
if label_format == "index":
|
||||
return batch_of_expected_string_labels
|
||||
@ -98,9 +103,7 @@ def expected_predictions_mapped(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_predictions(
|
||||
label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays
|
||||
):
|
||||
def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays):
|
||||
if label_format == "index":
|
||||
return batch_of_expected_numeric_labels
|
||||
elif label_format == "probability":
|
||||
@ -111,7 +114,7 @@ def expected_predictions(
|
||||
|
||||
@pytest.fixture
|
||||
def estimator_adapter(
|
||||
estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
|
||||
estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
|
||||
):
|
||||
if estimator_type == "mock":
|
||||
estimator_adapter = EstimatorAdapter(estimator_mock)
|
||||
@ -227,13 +230,14 @@ def map_labels(numeric_labels, classes):
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
|
||||
return [{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)]
|
||||
return [{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
|
||||
return [
|
||||
{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted)
|
||||
{"classification": epm, **mdt}
|
||||
for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted)
|
||||
]
|
||||
|
||||
|
||||
@ -330,7 +334,6 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
|
||||
@pytest.fixture
|
||||
def model():
|
||||
class Model:
|
||||
|
||||
@staticmethod
|
||||
def predict(*args):
|
||||
return True
|
||||
@ -377,12 +380,14 @@ def model_loader(database_connector):
|
||||
@pytest.fixture
|
||||
def mlflow_run_id():
|
||||
from image_prediction.config import CONFIG
|
||||
|
||||
return CONFIG.service.run_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mlruns_dir():
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
|
||||
return MLRUNS_DIR
|
||||
|
||||
|
||||
@ -394,7 +399,6 @@ def mlflow_reader(mlruns_dir):
|
||||
@pytest.fixture
|
||||
def model_handle_mock(estimator_mock):
|
||||
class ModelHandleMock:
|
||||
|
||||
def __init__(self):
|
||||
self.model = estimator_mock
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_
|
||||
|
||||
|
||||
def test_array_label_mapper(
|
||||
batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes
|
||||
batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes
|
||||
):
|
||||
mapper = ProbabilityMapper(classes)
|
||||
assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings
|
||||
|
||||
@ -18,4 +18,4 @@ def test_load_model_and_classes_from_mlflow_store(model_loader, mlflow_run_id):
|
||||
classes_loaded = model_loader.load_classes(mlflow_run_id)
|
||||
|
||||
assert type(model_loaded) == PredictionModelHandle
|
||||
assert classes_loaded == ['formula', 'logo', 'other', 'signature']
|
||||
assert classes_loaded == ["formula", "logo", "other", "signature"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user