Pull request #8: Pipeline refactoring

Merge in RR/image-prediction from pipeline_refactoring to tdd_refactoring

Squashed commit of the following:

commit 6989fcb3313007b7eecf4bba39077fcde6924a9a
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Apr 25 09:49:49 2022 +0200

    removed obsolete module

commit 7428aeee37b11c31cffa597c85b018ba71e79a1d
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Apr 25 09:45:45 2022 +0200

    refactoring

commit 0dcd3894154fdf34bd3ba4ef816362434474f472
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Apr 25 08:57:21 2022 +0200

    refactoring; removed obsolete extractor-classifier

commit 1078aa81144f4219149b3fcacdae8b09c4b905c0
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Apr 22 17:18:10 2022 +0200

    removed obsolete imports

commit 71f61fc5fc915da3941cf5ed5d9cc90fccc49031
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Apr 22 17:16:25 2022 +0200

    comment changed

commit b582726cd1de233edb55c5a76c91e99f9dd3bd13
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Apr 22 17:12:11 2022 +0200

    refactoring

commit 8abc9010048078868b235d6793ac6c8b20abb985
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Apr 21 21:25:47 2022 +0200

    formatting

commit 2c87c419fe3185a25c27139e7fcf79f60971ad24
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Apr 21 21:24:05 2022 +0200

    formatting

commit 50b161192db43a84464125c6d79650225e1010d6
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Apr 21 21:20:18 2022 +0200

    refactoring

commit 9a1446cccfa070852a5d9c0bdbc36037b82541fc
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Apr 21 21:04:57 2022 +0200

    refactoring

commit 6c10b55ff8e61412cb2fe5a5625e660ecaf1d7d1
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Apr 21 19:48:05 2022 +0200

    refactoring

commit 72e785e3e31c132ab352119e9921725f91fac9e2
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Apr 21 19:43:39 2022 +0200

    refactoring

commit f036ee55e6747daf31e3929bdc2d93dc5f2a56ca
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Apr 20 18:30:41 2022 +0200

    refactoring pipeline WIP
This commit is contained in:
Matthias Bisping 2022-04-25 10:08:49 +02:00
parent 120721f5f1
commit 26ef5fce8a
12 changed files with 53 additions and 79 deletions

View File

@ -24,10 +24,11 @@ class Classifier:
self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper) self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)
def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]: def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]:
if not isinstance(batch, tuple) and batch.shape[0] == 0:
if isinstance(batch, np.ndarray) and batch.shape[0] == 0:
return [] return []
return list(self.__pipe(batch)) return self.__pipe(batch)
def __call__(self, batch: np.array) -> List[str]: def __call__(self, batch: np.array) -> List[str]:
logger.debug("Classifier.predict") logger.debug("Classifier.predict")

View File

@ -4,16 +4,15 @@ from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.compositor.compositor import TransformerCompositor from image_prediction.compositor.compositor import TransformerCompositor
from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.transformer.transformers.response import ResponseTransformer
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader from image_prediction.redai_adapter.mlflow import MlflowModelReader
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
from image_prediction.transformer.transformers.response import ResponseTransformer
def get_mlflow_model_loader(mlruns_dir): def get_mlflow_model_loader(mlruns_dir):
@ -32,14 +31,6 @@ def get_extractor(**kwargs):
return image_extractor return image_extractor
def get_extractor_classifier(model_loader, model_identifier, **kwargs):
extractor_classifier = ExtractorClassifier(
get_extractor(**kwargs), get_image_classifier(model_loader, model_identifier)
)
return extractor_classifier
def get_formatter(): def get_formatter():
formatter = TransformerCompositor( formatter = TransformerCompositor(
PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter() PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter()

View File

@ -1,33 +0,0 @@
from itertools import chain
from typing import Iterable
from funcy import chunks, rpartial
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor
class ExtractorClassifier:
"""This class is responsible for orchestrating the pairing of classifications and image metadata. It extracts images
from an object and classifies them. Then it ties the classification together with the metadata. It returns an
iterable of dictionaries, where each dictionary has a field 'label' for the classification and possibly additional
fields for metadata -- metadata could be void.
"""
def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier):
self.classifier = image_classifier
self.extractor = image_extractor
def __process_batch(self, batch, batch_size):
images, metadata = zip(*batch)
predictions = self.classifier(images, batch_size)
responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
return responses
def __call__(self, obj, batch_size=16, **kwargs) -> Iterable[dict]:
image_metadata_pairs = self.extractor(obj, **kwargs)
batches = chunks(batch_size, image_metadata_pairs)
predictions = chain.from_iterable(map(rpartial(self.__process_batch, batch_size), batches))
yield from predictions

View File

@ -7,8 +7,7 @@ from typing import List
import fitz import fitz
from PIL import Image from PIL import Image
from funcy import rcompose, merge, pluck, curry, compose, rpartial from funcy import rcompose, merge, pluck, curry, compose
from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info from image_prediction.info import Info
@ -35,18 +34,10 @@ class ParsablePDFImageExtractor(ImageExtractor):
pages = extract_pages(self.doc, page_range) if page_range else self.doc pages = extract_pages(self.doc, page_range) if page_range else self.doc
pages = self.__maybe_show_progress(pages, page_range)
image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, pages)) image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, pages))
yield from image_metadata_pairs yield from image_metadata_pairs
def __maybe_show_progress(self, iterable, page_range):
return self.__progressbar(page_range)(iterable) if self.verbose else iterable
def __progressbar(self, page_range):
return rpartial(tqdm, desc=self.progress_message, total=len(page_range) if page_range else None)
def __process_images_on_page(self, page: fitz.fitz.Page): def __process_images_on_page(self, page: fitz.fitz.Page):
images = get_images_on_page(self.doc, page) images = get_images_on_page(self.doc, page)
metadata = get_metadata_for_images_on_page(self.doc, page) metadata = get_metadata_for_images_on_page(self.doc, page)

View File

@ -1,10 +1,14 @@
import os import os
from functools import partial
from itertools import chain, tee
from funcy import rcompose from funcy import rcompose, first, compose, second, chunks, identity
from tqdm import tqdm
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.default_objects import get_extractor_classifier, get_formatter, get_mlflow_model_loader from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
from image_prediction.locations import MLRUNS_DIR from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@ -18,9 +22,37 @@ def load_pipeline(**kwargs):
return pipeline return pipeline
def parallel(*fs):
return lambda *args: (f(a) for f, a in zip(fs, args))
def star(f):
return lambda x: f(*x)
class Pipeline: class Pipeline:
def __init__(self, model_loader, model_identifier, **kwargs): def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier, **kwargs), get_formatter())
extract = get_extractor(**kwargs)
classifier = get_image_classifier(model_loader, model_identifier)
reformat = get_formatter()
split = compose(star(parallel(*map(lift, (first, second)))), tee)
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
pairwise_apply = compose(star, parallel)
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip))
# +>--classify--v
# --extract-->--split--| |--join-->reformat
# +>--identity--^
self.pipe = rcompose(
extract, # ... image-metadata-pairs as a stream
split, # ... into an image stream and a metadata stream
pairwise_apply(classify, identity), # ... apply functions to the streams pairwise
join, # ... the streams by zipping
reformat, # ... the items
)
def __call__(self, pdf: bytes, page_range: range = None): def __call__(self, pdf: bytes, page_range: range = None):
yield from self.pipe(pdf, page_range=page_range) yield from tqdm(self.pipe(pdf, page_range=page_range), desc="Processing images from document", unit=" images")

View File

@ -1,3 +1,5 @@
from itertools import starmap
from funcy import iterate, first, curry, map from funcy import iterate, first, curry, map
@ -7,3 +9,7 @@ def until(cond, func, *args, **kwargs):
def lift(fn): def lift(fn):
return curry(map)(fn) return curry(map)(fn)
def starlift(fn):
return curry(starmap)(fn)

View File

@ -14,4 +14,4 @@ def image_extractor(extractor_type):
elif extractor_type == "default": elif extractor_type == "default":
return None return None
else: else:
raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.") raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")

View File

@ -4,7 +4,7 @@ import pytest
@pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"]) @pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"])
@pytest.mark.parametrize("label_format", ["index", "probability"]) @pytest.mark.parametrize("label_format", ["index", "probability"])
def test_classifier(classifier, input_batch, expected_predictions_mapped): def test_classifier(classifier, input_batch, expected_predictions_mapped):
predictions = classifier(input_batch) predictions = list(classifier(input_batch))
assert predictions == expected_predictions_mapped assert predictions == expected_predictions_mapped

View File

@ -1,14 +0,0 @@
from operator import itemgetter
import pytest
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
@pytest.mark.parametrize("extractor_type", ["mock"])
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
def test_extractor_classifier(image_extractor, image_classifier, images, batch_of_expected_string_labels):
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
results = extractor_classifier(images)
labels = list(map(itemgetter("classification"), results))
assert labels == batch_of_expected_string_labels

View File

@ -11,8 +11,8 @@ from image_prediction.extraction import extract_images_from_pdf
from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel
from image_prediction.info import Info from image_prediction.info import Info
from test.utils.comparison import metadata_equal, image_sets_equal
from test.utils.generation.pdf import add_image, pdf_stream from test.utils.generation.pdf import add_image, pdf_stream
from test.utils.comparison import images_equal, metadata_equal, image_sets_equal
@pytest.mark.parametrize("extractor_type", ["mock"]) @pytest.mark.parametrize("extractor_type", ["mock"])

View File

@ -37,12 +37,12 @@ def test_server_predict_failure(client, mute_logger):
def test_server_health_check(client): def test_server_health_check(client):
response = client.get("/ready") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
assert response.json == "OK" assert response.json == "OK"
def test_server_ready_check(client): def test_server_ready_check(client):
response = client.get("/health") response = client.get("/ready")
assert response.status_code == 200 assert response.status_code == 200
assert response.json == "OK" assert response.json == "OK"