diff --git a/image_prediction/formatter/__init__.py b/image_prediction/formatter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/formatter/formatter.py b/image_prediction/formatter/formatter.py new file mode 100644 index 0000000..2d71ee5 --- /dev/null +++ b/image_prediction/formatter/formatter.py @@ -0,0 +1,6 @@ +import abc + + +class Formatter(abc.ABC): + def format(self, info: dict): + pass diff --git a/image_prediction/formatter/formatters/__init__.py b/image_prediction/formatter/formatters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/formatter/formatters/info_formatter.py b/image_prediction/formatter/formatters/info_formatter.py new file mode 100644 index 0000000..ad658af --- /dev/null +++ b/image_prediction/formatter/formatters/info_formatter.py @@ -0,0 +1,13 @@ +from enum import Enum +from typing import List + +from image_prediction.formatter.formatter import Formatter + + +class EnumFormatter(Formatter): + + def format(self, metadata: dict): + return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()} + + def __call__(self, metadata: List[dict]): + return map(self.format, metadata) diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index 1dd6995..d680e96 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -1,5 +1,9 @@ import os +from funcy import rcompose + +from image_prediction.formatter.formatters.info_formatter import EnumFormatter + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" from image_prediction.classifier.classifier import Classifier @@ -36,10 +40,16 @@ def get_extractor_classifier(): return extractor_classifier +def get_formatter(): + formatter = EnumFormatter() + + return formatter + + class Pipeline: def __init__(self): - self.pipeline = get_extractor_classifier() + self.pipe = rcompose(get_extractor_classifier(), get_formatter()) def __call__(self, pdf: bytes): - return self.pipeline(pdf) + return self.pipe(pdf)