made classifier accept tupls of images in addition to np.arrays; added pipeline (wip)
This commit is contained in:
parent
3339ed2eab
commit
ade318c7b7
@ -1,6 +1,7 @@
|
||||
from typing import Mapping, List
|
||||
from typing import Mapping, List, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.utils import get_logger
|
||||
@ -21,9 +22,9 @@ class Classifier:
|
||||
self.__estimator_adapter = estimator_adapter
|
||||
self._classes = classes
|
||||
|
||||
def predict(self, batch: np.array) -> List[str]:
|
||||
def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]:
|
||||
|
||||
if batch.shape[0] == 0:
|
||||
if not isinstance(batch, tuple) and batch.shape[0] == 0:
|
||||
return []
|
||||
|
||||
return [self._classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
||||
|
||||
@ -5,14 +5,16 @@ from operator import itemgetter
|
||||
import fitz
|
||||
from PIL import Image
|
||||
from funcy import rcompose
|
||||
from tqdm import tqdm
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
|
||||
|
||||
class ParsablePDFImageExtractor(ImageExtractor):
|
||||
def __init__(self):
|
||||
def __init__(self, verbose=False):
|
||||
self.doc: fitz.fitz.Document = None
|
||||
self.verbose = verbose
|
||||
|
||||
def __process_images_on_page(self, page: fitz.fitz.Page):
|
||||
def load_image_from_xref(xref):
|
||||
@ -46,5 +48,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
|
||||
def extract(self, pdf: bytes):
|
||||
self.doc = fitz.Document(stream=pdf)
|
||||
image_metadata_pairs = chain(*map(self.__process_images_on_page, self.doc))
|
||||
image_metadata_pairs = chain(
|
||||
*map(self.__process_images_on_page, tqdm(self.doc, desc="Extracting", disable=not self.verbose))
|
||||
)
|
||||
return image_metadata_pairs
|
||||
|
||||
@ -1,16 +1,45 @@
|
||||
import os
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||
|
||||
|
||||
def get_image_classifier():
|
||||
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
|
||||
model = model_loader.load_model(CONFIG.service.run_id)
|
||||
classes = model_loader.load_classes(CONFIG.service.run_id)
|
||||
classifier = Classifier(EstimatorAdapter(model), classes)
|
||||
image_classifier = ImageClassifier(classifier)
|
||||
|
||||
return image_classifier
|
||||
|
||||
|
||||
def get_extractor():
|
||||
image_extractor = ParsablePDFImageExtractor(verbose=True)
|
||||
|
||||
return image_extractor
|
||||
|
||||
|
||||
def get_extractor_classifier():
|
||||
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier())
|
||||
|
||||
return extractor_classifier
|
||||
|
||||
|
||||
class Pipeline:
|
||||
|
||||
def __init__(self):
|
||||
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
|
||||
model = model_loader.load_model(CONFIG.service.run_id)
|
||||
classes = model_loader.load_classes(CONFIG.service.run_id)
|
||||
classifier = Classifier(EstimatorAdapter(model), classes)
|
||||
self.pipeline = get_extractor_classifier()
|
||||
|
||||
def __call__(self, pdf: bytes):
|
||||
return self.pipeline(pdf)
|
||||
|
||||
@ -3,7 +3,6 @@ import pytest
|
||||
|
||||
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
|
||||
from image_prediction.extraction import extract_images_from_pdf
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user