diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..dc6ce29 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,106 @@ +/build_venv/ +/.venv/ +/misc/ +/incl/image_service/test/ +/scratch/ +/bamboo-specs/ +README.md +Dockerfile +*idea +*misc +*egg-info +*pycache* + +# Git +.git +.gitignore + +# CI +.codeclimate.yml +.travis.yml +.taskcluster.yml + +# Docker +docker-compose.yml +.docker + +# Byte-compiled / optimized / DLL files +__pycache__/ +*/__pycache__/ +*/*/__pycache__/ +*/*/*/__pycache__/ +*.py[cod] +*/*.py[cod] +*/*/*.py[cod] +*/*/*/*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Virtual environment
.env/ +.venv/ +#venv/ + +# PyCharm +.idea + +# Python mode for VIM +.ropeproject +*/.ropeproject +*/*/.ropeproject +*/*/*/.ropeproject + +# Vim swap files +*.swp +*/*.swp +*/*/*.swp +*/*/*/*.swp \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 3124f9f..3ff31f4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,6 +8,12 @@ WORKDIR /app/service COPY ./src ./src COPY ./incl/detr ./incl/detr COPY ./fb_detr ./fb_detr +COPY ./setup.py ./setup.py +COPY ./requirements.txt ./requirements.txt +COPY ./config.yaml ./config.yaml + +# Install dependencies differing from base image. +RUN python3 -m pip install -r requirements.txt RUN python3 -m pip install -e . 
diff --git a/Dockerfile_base b/Dockerfile_base index d2a91c1..4085902 100644 --- a/Dockerfile_base +++ b/Dockerfile_base @@ -9,7 +9,8 @@ RUN python -m pip install --upgrade pip # Make a directory for the service files and copy the service repo into the container. WORKDIR /app/service -COPY . ./ +COPY ./requirements.txt ./requirements.txt +COPY ./data ./data # Install dependencies. RUN python3 -m pip install -r requirements.txt diff --git a/config.yaml b/config.yaml index f1b38a3..7569e6b 100644 --- a/config.yaml +++ b/config.yaml @@ -1,7 +1,15 @@ -device: cpu -threshold: .5 +estimator: + checkpoint: checkpoint.pth + classes: ["logo", "other", "formula", "signature", "handwriting_other"] + rejection_class: "other" + threshold: .5 + device: cpu -classes: ["logo", "other", "formula", "signature", "handwriting_other"] -rejection_class: "other" +webserver: + host: $SERVER_HOST|"127.0.0.1" # webserver address + port: $SERVER_PORT|5000 # webserver port + mode: $SERVER_MODE|production # webserver mode: {development, production} -checkpoint: checkpoint.pth +service: + logging_level: DEBUG + batch_size: $BATCH_SIZE|2 # Number of images in memory simultaneously per service instance diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..0afb54b --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,10 @@ +version: "3.3" +services: + detr-server: + image: detr-server + network_mode: "host" + read_only: true + volumes: + - tmp:/tmp:rw +volumes: + tmp: diff --git a/fb_detr/config.py b/fb_detr/config.py new file mode 100644 index 0000000..49dc564 --- /dev/null +++ b/fb_detr/config.py @@ -0,0 +1,40 @@ +"""Implements a config object with dot-indexing syntax.""" + + +from envyaml import EnvYAML + +from fb_detr.locations import CONFIG_FILE + + +def _get_item_and_maybe_make_dotindexable(container, item): + ret = container[item] + return DotIndexable(ret) if isinstance(ret, dict) else ret + + +class DotIndexable: + def __init__(self, x): + self.x = x + + 
def __getattr__(self, item): + return _get_item_and_maybe_make_dotindexable(self.x, item) + + def __setitem__(self, key, value): + self.x[key] = value + + def __repr__(self): + return self.x.__repr__() + + +class Config: + def __init__(self, config_path): + self.__config = EnvYAML(config_path) + + def __getattr__(self, item): + if item in self.__config: + return _get_item_and_maybe_make_dotindexable(self.__config, item) + + def __getitem__(self, item): + return self.__getattr__(item) + + +CONFIG = Config(CONFIG_FILE) diff --git a/fb_detr/predictor.py b/fb_detr/predictor.py index b7def19..a8f89eb 100644 --- a/fb_detr/predictor.py +++ b/fb_detr/predictor.py @@ -1,16 +1,19 @@ import argparse -from itertools import compress, starmap +import logging +from itertools import compress, starmap, chain from operator import itemgetter from pathlib import Path from typing import Iterable import torch -from detr.models import build_model -from detr.test import get_args_parser, infer from iteration_utilities import starfilter +from tqdm import tqdm +from detr.models import build_model +from detr.prediction import get_args_parser, infer +from fb_detr.config import CONFIG from fb_detr.utils.non_max_supprs import greedy_non_max_supprs -from fb_detr.utils.config import read_config +from fb_detr.utils.stream import stream_pages, chunk_iterable, get_page_count def load_model(checkpoint_path): @@ -21,7 +24,7 @@ def load_model(checkpoint_path): if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) - device = torch.device(read_config("device")) + device = torch.device(CONFIG.estimator.device) model, _, _ = build_model(args) @@ -102,6 +105,7 @@ class Predictor: def detections_present(_, prediction): return bool(prediction["classes"]) + # TODO: set page_idx even when not filtering def build_return_dict(page_idx, predictions): return {"page_idx": page_idx, **predictions} @@ -114,22 +118,39 @@ class Predictor: def format_predictions(self, outputs: Iterable): return 
map(self.__format_prediction, outputs) - def __merge_boxes(self, predictions): + def __non_max_supprs(self, predictions): predictions = map(greedy_non_max_supprs, predictions) return predictions def predict(self, images, threshold=None): if not threshold: - threshold = read_config("threshold") + threshold = CONFIG.estimator.threshold - predictions = infer(images, self.model, read_config("device"), threshold) + predictions = infer(images, self.model, CONFIG.estimator.device, threshold) predictions = self.format_predictions(predictions) if self.rejection_class: predictions = self.filter_predictions(predictions) - predictions = self.__merge_boxes(predictions) + predictions = self.__non_max_supprs(predictions) predictions = list(predictions) return predictions + + def predict_pdf(self, pdf: bytes): + def predict_batch(batch_idx, batch): + predictions = self.predict(batch) + for p in predictions: + p["page_idx"] += batch_idx + + return predictions + + page_count = get_page_count(pdf) + batch_count = int(page_count / CONFIG.service.batch_size) + + page_stream = stream_pages(pdf) + page_batches = chunk_iterable(page_stream, CONFIG.service.batch_size) + predictions = list(chain(*starmap(predict_batch, tqdm(enumerate(page_batches), total=batch_count)))) + + return predictions diff --git a/fb_detr/utils/config.py b/fb_detr/utils/config.py deleted file mode 100644 index c0d9334..0000000 --- a/fb_detr/utils/config.py +++ /dev/null @@ -1,18 +0,0 @@ -import yaml - -from fb_detr.locations import CONFIG_FILE - - -def read_config(key, config_path: str = CONFIG_FILE): - """Reads the values associated with a key from a config. - - Args: - key: Key to look up the value to. - config_path: Path to config. - - Returns: - The value associated with `key`. 
- """ - with open(config_path) as f: - config = yaml.load(f, Loader=yaml.FullLoader) - return config[key] diff --git a/fb_detr/utils/estimator.py b/fb_detr/utils/estimator.py new file mode 100644 index 0000000..04c6922 --- /dev/null +++ b/fb_detr/utils/estimator.py @@ -0,0 +1,32 @@ +import os + +from fb_detr.config import CONFIG +from fb_detr.locations import DATA_DIR, TORCH_HOME +from fb_detr.predictor import Predictor + + +def suppress_userwarnings(): + import warnings + + warnings.filterwarnings("ignore") + + +def load_classes(): + classes = CONFIG.estimator.classes + id2class = dict(zip(range(1, len(classes) + 1), classes)) + return id2class + + +def get_checkpoint(): + return DATA_DIR / CONFIG.estimator.checkpoint + + +def set_torch_env(): + os.environ["TORCH_HOME"] = str(TORCH_HOME) + + +def initialize_predictor(resume): + set_torch_env() + checkpoint = get_checkpoint() if not resume else resume + predictor = Predictor(checkpoint, classes=load_classes(), rejection_class=CONFIG.estimator.rejection_class) + return predictor diff --git a/fb_detr/utils/stream.py b/fb_detr/utils/stream.py new file mode 100644 index 0000000..d9948a3 --- /dev/null +++ b/fb_detr/utils/stream.py @@ -0,0 +1,20 @@ +from itertools import takewhile, starmap, islice, repeat +from operator import truth + +from pdf2image import pdf2image + + +def chunk_iterable(iterable, n): + return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), n))))) + + +def get_page_count(pdf): + return pdf2image.pdfinfo_from_bytes(pdf)["Pages"] + + +def stream_pages(pdf): + def page_to_image(idx): + return pdf2image.convert_from_bytes(pdf, first_page=idx, last_page=idx + 1)[0] + + page_count = get_page_count(pdf) + return map(page_to_image, range(page_count)) diff --git a/incl/detr b/incl/detr index ad13cae..7720238 160000 --- a/incl/detr +++ b/incl/detr @@ -1 +1 @@ -Subproject commit ad13caea7faf7f0285290d20a58097ae273e4c24 +Subproject commit 772023801e4fd3deef7953f7f49fd6fb2bf60236 diff --git 
a/requirements.txt b/requirements.txt index e28f04f..c4acf39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,10 @@ torch==1.10.2 numpy==1.22.1 -#opencv-python==4.5.5.62 opencv-python-headless==4.5.5.62 torchvision==0.11.3 pycocotools==2.0.4 scipy==1.7.3 pdf2image==1.16.0 -PyYAML==6.0 Flask==2.0.2 requests==2.27.1 iteration-utilities==0.11.0 @@ -14,3 +12,5 @@ dvc==2.9.3 dvc[ssh] frozendict==2.3.0 waitress==2.0.0 +envyaml==1.10.211231 +# PyYAML==6.0 diff --git a/scripts/predict.py b/scripts/predict.py deleted file mode 100644 index 3c13358..0000000 --- a/scripts/predict.py +++ /dev/null @@ -1,58 +0,0 @@ -import argparse -import json -from pathlib import Path - -from detr.test import draw_boxes -from pdf2image import pdf2image - -from fb_detr.predictor import Predictor - - -def parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument("--resume", required=True) - parser.add_argument("--output_dir", required=True) - parser.add_argument("--pdf_path") - parser.add_argument("--draw_boxes", default=False, action="store_true") - - args = parser.parse_args() - - return args - - -def build_image_paths(image_root_dir): - return [*map(str, Path(image_root_dir).glob("*.png"))] - - -def pdf_to_pages(pdf_path): - pages = pdf2image.convert_from_path(pdf_path) - return pages - - -def main(): - - # TDOO: de-hardcode - - classes = {1: "logo", 2: "other", 3: "formula", 4: "signature", 5: "handwriting_other"} - - args = parse_args() - predictor = Predictor(args.resume, classes=classes, rejection_class="other") - - images = pdf_to_pages(args.pdf_path) - outputs = predictor.predict(images, 0.5) - - if args.draw_boxes: - for im, o in zip(images, outputs): - if len(o["bboxes"]): - draw_boxes(image=im, **o, output_path=args.output_dir) - - else: - outputs = predictor.format_predictions(outputs) - outputs = predictor.filter_predictions(outputs) - for o in outputs: - print(json.dumps(o, indent=2)) - - -if __name__ == "__main__": - main() diff --git 
a/setup/docker.sh b/setup/docker.sh index f23d4aa..f6b670e 100755 --- a/setup/docker.sh +++ b/setup/docker.sh @@ -11,5 +11,5 @@ dvc pull git submodule update --init --recursive -docker build -f Dockerfile_base -t detr-server-base . +docker build -f Dockerfile_base -t fb_detr_prediction_container-base . docker build -f Dockerfile -t detr-server . diff --git a/src/serve.py b/src/serve.py index 1b8cb7a..76cddf5 100644 --- a/src/serve.py +++ b/src/serve.py @@ -1,20 +1,16 @@ import argparse +import json import logging -import os +from itertools import chain +from typing import Callable from flask import Flask, request, jsonify from pdf2image import pdf2image +from waitress import serve -from fb_detr.locations import DATA_DIR -from fb_detr.locations import TORCH_HOME -from fb_detr.predictor import Predictor -from fb_detr.utils.config import read_config - - -def suppress_userwarnings(): - import warnings - - warnings.filterwarnings("ignore") +from fb_detr.config import CONFIG +from fb_detr.utils.estimator import suppress_userwarnings, initialize_predictor +from fb_detr.utils.stream import stream_pages, chunk_iterable def parse_args(): @@ -26,36 +22,19 @@ def parse_args(): return args -def load_classes(): - classes = read_config("classes") - id2class = dict(zip(range(1, len(classes) + 1), classes)) - return id2class - - -def get_checkpoint(): - return DATA_DIR / read_config("checkpoint") - - -def set_torch_env(): - os.environ["TORCH_HOME"] = str(TORCH_HOME) - - def main(args): - if not args.warnings: suppress_userwarnings() - run_server(args.resume) + predictor = initialize_predictor(args.resume) + logging.info("Predictor ready.") + + prediction_server = make_prediction_server(predictor.predict_pdf) + + run_prediction_server(prediction_server, mode=CONFIG.webserver.mode) -def run_server(resume): - - set_torch_env() - - def initialize_predictor(): - checkpoint = get_checkpoint() if not resume else resume - predictor = Predictor(checkpoint, classes=load_classes(), 
rejection_class=read_config("rejection_class")) - return predictor +def make_prediction_server(predict_fn: Callable): app = Flask(__name__) @@ -72,37 +51,54 @@ def run_server(resume): return resp @app.route("/", methods=["POST"]) - def predict_request(): - def inner(): + def predict(): + def __predict(): - pdf = request.data + def inner(): - pages = pdf2image.convert_from_bytes(pdf) - predictions = predictor.predict(pages) + pdf = request.data - return jsonify(list(predictions)) + logging.debug("Running predictor on document...") + predictions = predict_fn(pdf) + logging.debug(f"Found {sum(map(len, predictions))} images in document.") + response = jsonify(list(predictions)) + + return response + + logging.info(f"Analyzing...") + result = inner() + logging.info("Analysis completed.") + return result try: - return inner() + return __predict() except Exception as err: - logging.warning("Analysis failed") + logging.warning("Analysis failed.") logging.exception(err) - resp = jsonify("Analysis failed") - resp.status_code = 500 - return resp + response = jsonify("Analysis failed.") + response.status_code = 500 + return response - @app.route("/status", methods=["GET"]) - def status(): - response = "OK" - return jsonify(response) + return app - predictor = initialize_predictor() - logging.info("Predictor ready.") +def run_prediction_server(app, mode="development"): - app.run(host="127.0.0.1", port=5000, debug=True) + if mode == "development": + app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True) + elif mode == "production": + serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port) if __name__ == "__main__": + + logging_level = CONFIG.service.logging_level + logging.basicConfig(level=logging_level) + logging.getLogger("flask").setLevel(logging.ERROR) + logging.getLogger("urllib3").setLevel(logging.ERROR) + logging.getLogger("werkzeug").setLevel(logging.ERROR) + logging.getLogger("waitress").setLevel(logging.ERROR) + args = parse_args() + 
main(args)