made classifier accept tupls of images in addition to np.arrays; added pipeline (wip)

This commit is contained in:
Matthias Bisping 2022-03-29 22:00:34 +02:00
parent 3339ed2eab
commit ade318c7b7
4 changed files with 43 additions and 10 deletions

View File

@ -1,6 +1,7 @@
from typing import Mapping, List
from typing import Mapping, List, Union, Tuple
import numpy as np
from PIL.Image import Image
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.utils import get_logger
@ -21,9 +22,9 @@ class Classifier:
self.__estimator_adapter = estimator_adapter
self._classes = classes
def predict(self, batch: np.array) -> List[str]:
def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]:
if batch.shape[0] == 0:
if not isinstance(batch, tuple) and batch.shape[0] == 0:
return []
return [self._classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]

View File

@ -5,14 +5,16 @@ from operator import itemgetter
import fitz
from PIL import Image
from funcy import rcompose
from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info
class ParsablePDFImageExtractor(ImageExtractor):
def __init__(self):
def __init__(self, verbose=False):
self.doc: fitz.fitz.Document = None
self.verbose = verbose
def __process_images_on_page(self, page: fitz.fitz.Page):
def load_image_from_xref(xref):
@ -46,5 +48,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
def extract(self, pdf: bytes):
self.doc = fitz.Document(stream=pdf)
image_metadata_pairs = chain(*map(self.__process_images_on_page, self.doc))
image_metadata_pairs = chain(
*map(self.__process_images_on_page, tqdm(self.doc, desc="Extracting", disable=not self.verbose))
)
return image_metadata_pairs

View File

@ -1,16 +1,45 @@
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.config import CONFIG
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.locations import MLRUNS_DIR
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader
def get_image_classifier():
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
model = model_loader.load_model(CONFIG.service.run_id)
classes = model_loader.load_classes(CONFIG.service.run_id)
classifier = Classifier(EstimatorAdapter(model), classes)
image_classifier = ImageClassifier(classifier)
return image_classifier
def get_extractor():
image_extractor = ParsablePDFImageExtractor(verbose=True)
return image_extractor
def get_extractor_classifier():
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier())
return extractor_classifier
class Pipeline:
def __init__(self):
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
model = model_loader.load_model(CONFIG.service.run_id)
classes = model_loader.load_classes(CONFIG.service.run_id)
classifier = Classifier(EstimatorAdapter(model), classes)
self.pipeline = get_extractor_classifier()
def __call__(self, pdf: bytes):
return self.pipeline(pdf)

View File

@ -3,7 +3,6 @@ import pytest
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
from image_prediction.extraction import extract_images_from_pdf
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
@pytest.mark.parametrize("extractor_type", ["mock"])