refactoring; removed obsolete extractor-classifier

This commit is contained in:
Matthias Bisping 2022-04-25 08:57:21 +02:00
parent 1078aa8114
commit 0dcd389415
8 changed files with 7 additions and 73 deletions

View File

@ -24,9 +24,6 @@ class Classifier:
self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)
def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]:
# TODO: necessary?
if not batch:
return []
if isinstance(batch, np.ndarray) and batch.shape[0] == 0:
return []

View File

@ -4,16 +4,15 @@ from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.compositor.compositor import TransformerCompositor
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.transformer.transformers.response import ResponseTransformer
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
from image_prediction.transformer.transformers.response import ResponseTransformer
def get_mlflow_model_loader(mlruns_dir):
@ -32,14 +31,6 @@ def get_extractor(**kwargs):
return image_extractor
def get_extractor_classifier(model_loader, model_identifier, **kwargs):
extractor_classifier = ExtractorClassifier(
get_extractor(**kwargs), get_image_classifier(model_loader, model_identifier)
)
return extractor_classifier
def get_formatter():
formatter = TransformerCompositor(
PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter()

View File

@ -1,33 +0,0 @@
from itertools import chain
from typing import Iterable
from funcy import chunks, rpartial
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor
class ExtractorClassifier:
"""This class is responsible for orchestrating the pairing of classifications and image metadata. It extracts images
from an object and classifies them. Then it ties the classification together with the metadata. It returns an
iterable of dictionaries, where each dictionary has a field 'label' for the classification and possibly additional
fields for metadata -- metadata could be void.
"""
def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier):
self.classifier = image_classifier
self.extractor = image_extractor
def __process_batch(self, batch, batch_size):
images, metadata = zip(*batch)
predictions = self.classifier(images, batch_size)
responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
return responses
def __call__(self, obj, batch_size=16, **kwargs) -> Iterable[dict]:
image_metadata_pairs = self.extractor(obj, **kwargs)
batches = chunks(batch_size, image_metadata_pairs)
predictions = chain.from_iterable(map(rpartial(self.__process_batch, batch_size), batches))
yield from predictions

View File

@ -7,8 +7,7 @@ from typing import List
import fitz
from PIL import Image
from funcy import rcompose, merge, pluck, curry, compose, rpartial
from tqdm import tqdm
from funcy import rcompose, merge, pluck, curry, compose
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info
@ -35,18 +34,10 @@ class ParsablePDFImageExtractor(ImageExtractor):
pages = extract_pages(self.doc, page_range) if page_range else self.doc
# pages = self.__maybe_show_progress(pages, page_range)
image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, pages))
yield from image_metadata_pairs
# def __maybe_show_progress(self, iterable, page_range):
# return self.__progressbar(page_range)(iterable) if self.verbose else iterable
#
# def __progressbar(self, page_range):
# return rpartial(tqdm, desc=self.progress_message, total=len(page_range) if page_range else None)
def __process_images_on_page(self, page: fitz.fitz.Page):
images = get_images_on_page(self.doc, page)
metadata = get_metadata_for_images_on_page(self.doc, page)

View File

@ -3,6 +3,7 @@ from functools import partial
from itertools import chain, tee
from funcy import rcompose, first, compose, second, chunks, identity
from tqdm import tqdm
from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
@ -31,6 +32,7 @@ def star(f):
class Pipeline:
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
extract = get_extractor(**kwargs)
classifier = get_image_classifier(model_loader, model_identifier)
reformat = get_formatter()
@ -53,4 +55,4 @@ class Pipeline:
)
def __call__(self, pdf: bytes, page_range: range = None):
yield from self.pipe(pdf, page_range=page_range)
yield from tqdm(self.pipe(pdf, page_range=page_range))

View File

@ -14,4 +14,4 @@ def image_extractor(extractor_type):
elif extractor_type == "default":
return None
else:
raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")

View File

@ -1,14 +0,0 @@
from operator import itemgetter
import pytest
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
@pytest.mark.parametrize("extractor_type", ["mock"])
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
def test_extractor_classifier(image_extractor, image_classifier, images, batch_of_expected_string_labels):
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
results = extractor_classifier(images)
labels = list(map(itemgetter("classification"), results))
assert labels == batch_of_expected_string_labels

View File

@ -11,8 +11,8 @@ from image_prediction.extraction import extract_images_from_pdf
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel
from image_prediction.info import Info
from test.utils.comparison import metadata_equal, image_sets_equal
from test.utils.generation.pdf import add_image, pdf_stream
from test.utils.comparison import images_equal, metadata_equal, image_sets_equal
@pytest.mark.parametrize("extractor_type", ["mock"])