refactoring

This commit is contained in:
Matthias Bisping 2022-04-25 11:19:26 +02:00
parent 6010133782
commit 75748a1d82
3 changed files with 11 additions and 6 deletions

View File

@ -17,7 +17,7 @@ from image_prediction.utils.generic import lift
class ParsablePDFImageExtractor(ImageExtractor): class ParsablePDFImageExtractor(ImageExtractor):
def __init__(self, verbose=False, tolerance=0, progress_message=None): def __init__(self, verbose=False, tolerance=0):
""" """
Args: Args:
@ -27,7 +27,6 @@ class ParsablePDFImageExtractor(ImageExtractor):
self.doc: fitz.fitz.Document = None self.doc: fitz.fitz.Document = None
self.verbose = verbose self.verbose = verbose
self.tolerance = tolerance self.tolerance = tolerance
self.progress_message = progress_message
def extract(self, pdf: bytes, page_range: range = None): def extract(self, pdf: bytes, page_range: range = None):
self.doc = fitz.Document(stream=pdf) self.doc = fitz.Document(stream=pdf)

View File

@ -17,7 +17,7 @@ def load_pipeline(**kwargs):
model_loader = get_mlflow_model_loader(MLRUNS_DIR) model_loader = get_mlflow_model_loader(MLRUNS_DIR)
model_identifier = CONFIG.service.run_id model_identifier = CONFIG.service.run_id
pipeline = Pipeline(model_loader, model_identifier, progress_message="Processing document", **kwargs) pipeline = Pipeline(model_loader, model_identifier, **kwargs)
return pipeline return pipeline
@ -31,7 +31,8 @@ def star(f):
class Pipeline: class Pipeline:
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs): def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
self.verbose = verbose
extract = get_extractor(**kwargs) extract = get_extractor(**kwargs)
classifier = get_image_classifier(model_loader, model_identifier) classifier = get_image_classifier(model_loader, model_identifier)
@ -55,4 +56,9 @@ class Pipeline:
) )
def __call__(self, pdf: bytes, page_range: range = None): def __call__(self, pdf: bytes, page_range: range = None):
yield from tqdm(self.pipe(pdf, page_range=page_range), desc="Processing images from document", unit=" images") yield from tqdm(
self.pipe(pdf, page_range=page_range),
desc="Processing images from document",
unit=" images",
disable=not self.verbose,
)

View File

@ -6,7 +6,7 @@ import requests
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--pdf_path", required=True) parser.add_argument("pdf_path")
args = parser.parse_args() args = parser.parse_args()
return args return args