25 lines
690 B
Python
25 lines
690 B
Python
import pytest
|
|
|
|
from image_prediction.exceptions import UnknownLabelFormat
|
|
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
|
|
|
|
|
@pytest.fixture
def label_mapper(label_format, classes):
    """Build the label mapper that corresponds to ``label_format``.

    Args:
        label_format: Name of the label format; ``"index"`` selects
            ``IndexMapper`` and ``"probability"`` selects ``ProbabilityMapper``.
        classes: Class labels handed to the mapper's constructor.

    Returns:
        A mapper instance constructed from ``classes``.

    Raises:
        UnknownLabelFormat: If ``label_format`` matches none of the known names.
    """
    # Checked in order; the first matching format name wins.
    known_mappers = (
        ("index", IndexMapper),
        ("probability", ProbabilityMapper),
    )
    for format_name, mapper_cls in known_mappers:
        if label_format == format_name:
            return mapper_cls(classes)

    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
|
|
|
|
|
@pytest.fixture(params=["index"])
def label_format(request):
    """Provide the label format for the current parametrized run.

    The fixture is parametrized with a single value, ``"index"``; the active
    parameter is read from ``request.param`` and returned unchanged.
    """
    selected_format = request.param
    return selected_format
|
|
|
|
|
|
@pytest.fixture
def classes():
    """Provide the list of class labels used by the tests: ``A``, ``B``, ``C``."""
    # A new list is built on every call, so callers never share one instance.
    class_names = ["A", "B", "C"]
    return class_names