refactoring pipeline WIP
This commit is contained in:
parent
120721f5f1
commit
f036ee55e6
@ -24,10 +24,14 @@ class Classifier:
|
||||
self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)
|
||||
|
||||
def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]:
    """Run the classification pipe over `batch` and return the results as a list.

    An empty batch short-circuits to an empty list without invoking the pipe.

    Args:
        batch: Either a NumPy array, whose emptiness is judged by the size of
            its first dimension, or another container (e.g. a tuple of
            images), whose emptiness is judged by truthiness.

    Returns:
        The items produced by the internal pipe, materialised into a list.
    """
    if isinstance(batch, np.ndarray):
        # Arrays must be checked through their shape: `not array` raises
        # ValueError for arrays holding more than one element.
        is_empty = batch.shape[0] == 0
    else:
        is_empty = not batch

    if is_empty:
        return []

    # TODO: confirm whether callers need a list or could consume the pipe lazily.
    return list(self.__pipe(batch))
|
||||
|
||||
def __call__(self, batch: np.array) -> List[str]:
|
||||
logger.debug("Classifier.predict")
|
||||
|
||||
@ -35,17 +35,17 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
|
||||
pages = extract_pages(self.doc, page_range) if page_range else self.doc
|
||||
|
||||
pages = self.__maybe_show_progress(pages, page_range)
|
||||
# pages = self.__maybe_show_progress(pages, page_range)
|
||||
|
||||
image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, pages))
|
||||
|
||||
yield from image_metadata_pairs
|
||||
|
||||
def __maybe_show_progress(self, iterable, page_range):
    """Return `iterable` wrapped in a progress bar when `self.verbose` is truthy.

    When verbosity is off the iterable is handed back untouched.
    """
    if not self.verbose:
        return iterable

    wrap_with_progress = self.__progressbar(page_range)
    return wrap_with_progress(iterable)
|
||||
|
||||
def __progressbar(self, page_range):
    """Build a `tqdm` wrapper preconfigured with this object's progress message.

    The bar's total is the length of `page_range` when a non-empty range is
    given; otherwise it is left as None.
    """
    description = self.progress_message
    total = len(page_range) if page_range else None
    return rpartial(tqdm, desc=description, total=total)
|
||||
# def __maybe_show_progress(self, iterable, page_range):
|
||||
# return self.__progressbar(page_range)(iterable) if self.verbose else iterable
|
||||
#
|
||||
# def __progressbar(self, page_range):
|
||||
# return rpartial(tqdm, desc=self.progress_message, total=len(page_range) if page_range else None)
|
||||
|
||||
def __process_images_on_page(self, page: fitz.fitz.Page):
|
||||
images = get_images_on_page(self.doc, page)
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import os
|
||||
from itertools import chain
|
||||
|
||||
from funcy import rcompose
|
||||
from funcy import rcompose, juxt, first, compose, second, chunks, curry
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.default_objects import get_extractor_classifier, get_formatter, get_mlflow_model_loader
|
||||
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
from image_prediction.utils.generic import lift, starlift
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
@ -19,8 +21,37 @@ def load_pipeline(**kwargs):
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, model_loader, model_identifier, **kwargs):
    """Compose the combined extractor/classifier with the formatter into `self.pipe`."""
    extractor_classifier = get_extractor_classifier(model_loader, model_identifier, **kwargs)
    formatter = get_formatter()
    self.pipe = rcompose(extractor_classifier, formatter)
|
||||
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
    """Assemble the extract -> batch -> classify -> merge -> format pipeline.

    Args:
        model_loader: Forwarded to `get_image_classifier`.
        model_identifier: Forwarded to `get_image_classifier`.
        batch_size: Number of extracted items per chunk handed downstream.
        **kwargs: Forwarded to `get_extractor`.
    """
    extractor = get_extractor(**kwargs)
    classifier = get_image_classifier(model_loader, model_identifier)
    formatter = get_formatter()

    def batcher(items):
        # Split the extractor's output stream into chunks of `batch_size`.
        return chunks(batch_size, items)

    def merger(predictions, metadata):
        # Pair every prediction with its metadata dict, storing the prediction
        # under the "classification" key.
        return ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))

    # Every extracted item is treated as a pair: the first element is fed to
    # the classifier, the second is merged into the result as metadata.
    # NOTE(review): presumably (image, metadata) pairs -- confirm against the extractor.
    left = compose(classifier, lift(first))
    right = lift(second)

    self.pipe = rcompose(
        extractor,
        batcher,
        # Turn each chunk into a list; both branches below receive the same chunk.
        lift(list),
        # Fan out: classify the first elements, collect the second elements.
        lift(juxt(left, right)),
        # Fan in: zip predictions and metadata back together per chunk.
        starlift(merger),
        # Flatten the per-chunk results into a single stream.
        chain.from_iterable,
        formatter,
    )
|
||||
|
||||
def __call__(self, pdf: bytes, page_range: range = None):
|
||||
yield from self.pipe(pdf, page_range=page_range)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from itertools import starmap
|
||||
|
||||
from funcy import iterate, first, curry, map
|
||||
|
||||
|
||||
@ -7,3 +9,7 @@ def until(cond, func, *args, **kwargs):
|
||||
|
||||
def lift(fn):
    """Return this module's `map`, curried and bound to `fn` as its first argument."""
    curried_map = curry(map)
    return curried_map(fn)
|
||||
|
||||
|
||||
def starlift(fn):
    """Return `starmap`, curried and bound to `fn` as its first argument."""
    curried_starmap = curry(starmap)
    return curried_starmap(fn)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user