refactoring
This commit is contained in:
parent
6010133782
commit
75748a1d82
@ -17,7 +17,7 @@ from image_prediction.utils.generic import lift
|
|||||||
|
|
||||||
|
|
||||||
class ParsablePDFImageExtractor(ImageExtractor):
|
class ParsablePDFImageExtractor(ImageExtractor):
|
||||||
def __init__(self, verbose=False, tolerance=0, progress_message=None):
|
def __init__(self, verbose=False, tolerance=0):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -27,7 +27,6 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
self.doc: fitz.fitz.Document = None
|
self.doc: fitz.fitz.Document = None
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.tolerance = tolerance
|
self.tolerance = tolerance
|
||||||
self.progress_message = progress_message
|
|
||||||
|
|
||||||
def extract(self, pdf: bytes, page_range: range = None):
|
def extract(self, pdf: bytes, page_range: range = None):
|
||||||
self.doc = fitz.Document(stream=pdf)
|
self.doc = fitz.Document(stream=pdf)
|
||||||
|
|||||||
@ -17,7 +17,7 @@ def load_pipeline(**kwargs):
|
|||||||
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
|
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
|
||||||
model_identifier = CONFIG.service.run_id
|
model_identifier = CONFIG.service.run_id
|
||||||
|
|
||||||
pipeline = Pipeline(model_loader, model_identifier, progress_message="Processing document", **kwargs)
|
pipeline = Pipeline(model_loader, model_identifier, **kwargs)
|
||||||
|
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
@ -31,7 +31,8 @@ def star(f):
|
|||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
|
def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
extract = get_extractor(**kwargs)
|
extract = get_extractor(**kwargs)
|
||||||
classifier = get_image_classifier(model_loader, model_identifier)
|
classifier = get_image_classifier(model_loader, model_identifier)
|
||||||
@ -55,4 +56,9 @@ class Pipeline:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, pdf: bytes, page_range: range = None):
|
def __call__(self, pdf: bytes, page_range: range = None):
|
||||||
yield from tqdm(self.pipe(pdf, page_range=page_range), desc="Processing images from document", unit=" images")
|
yield from tqdm(
|
||||||
|
self.pipe(pdf, page_range=page_range),
|
||||||
|
desc="Processing images from document",
|
||||||
|
unit=" images",
|
||||||
|
disable=not self.verbose,
|
||||||
|
)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import requests
|
|||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--pdf_path", required=True)
|
parser.add_argument("pdf_path")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user