applied black

This commit is contained in:
Matthias Bisping 2022-04-01 19:50:44 +02:00
parent c372529ee5
commit da9b3d0cb9
3 changed files with 17 additions and 13 deletions

View File

@ -1,4 +1,5 @@
import random
import random
import string
import tempfile
from functools import partial
@ -15,8 +16,12 @@ from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType, \
UnknownLabelFormat
from image_prediction.exceptions import (
UnknownEstimatorAdapter,
UnknownImageExtractor,
UnknownDatabaseType,
UnknownLabelFormat,
)
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
@ -87,7 +92,7 @@ def label_format(request):
@pytest.fixture
def expected_predictions_mapped(
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
if label_format == "index":
return batch_of_expected_string_labels
@ -98,9 +103,7 @@ def expected_predictions_mapped(
@pytest.fixture
def expected_predictions(
label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays
):
def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays):
if label_format == "index":
return batch_of_expected_numeric_labels
elif label_format == "probability":
@ -111,7 +114,7 @@ def expected_predictions(
@pytest.fixture
def estimator_adapter(
estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
):
if estimator_type == "mock":
estimator_adapter = EstimatorAdapter(estimator_mock)
@ -227,13 +230,14 @@ def map_labels(numeric_labels, classes):
@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
return [{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)]
return [{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)]
@pytest.fixture
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
return [
{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted)
{"classification": epm, **mdt}
for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted)
]
@ -330,7 +334,6 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
@pytest.fixture
def model():
class Model:
@staticmethod
def predict(*args):
return True
@ -377,12 +380,14 @@ def model_loader(database_connector):
@pytest.fixture
def mlflow_run_id():
from image_prediction.config import CONFIG
return CONFIG.service.run_id
@pytest.fixture
def mlruns_dir():
from image_prediction.locations import MLRUNS_DIR
return MLRUNS_DIR
@ -394,7 +399,6 @@ def mlflow_reader(mlruns_dir):
@pytest.fixture
def model_handle_mock(estimator_mock):
class ModelHandleMock:
def __init__(self):
self.model = estimator_mock

View File

@ -13,7 +13,7 @@ def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_
def test_array_label_mapper(
batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes
batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes
):
mapper = ProbabilityMapper(classes)
assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings

View File

@ -18,4 +18,4 @@ def test_load_model_and_classes_from_mlflow_store(model_loader, mlflow_run_id):
classes_loaded = model_loader.load_classes(mlflow_run_id)
assert type(model_loaded) == PredictionModelHandle
assert classes_loaded == ['formula', 'logo', 'other', 'signature']
assert classes_loaded == ["formula", "logo", "other", "signature"]