From b582726cd1de233edb55c5a76c91e99f9dd3bd13 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 22 Apr 2022 17:12:11 +0200 Subject: [PATCH] refactoring --- image_prediction/pipeline.py | 15 ++++++++------- test/unit_tests/mocked_server_test.py | 4 ++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index caa59d0..c14b08b 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -1,8 +1,8 @@ import os -from functools import partial +from functools import partial, reduce from itertools import chain, tee -from funcy import rcompose, first, compose, second, chunks, identity +from funcy import rcompose, first, compose, second, chunks, identity, curry from image_prediction.config import CONFIG from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor @@ -25,7 +25,7 @@ def parallel(*fs): return lambda *args: (f(a) for f, a in zip(fs, args)) -def splat(f): +def star(f): return lambda x: f(*x) @@ -35,18 +35,19 @@ class Pipeline: classifier = get_image_classifier(model_loader, model_identifier) reformat = get_formatter() - split = compose(splat(parallel(*map(lift, (first, second)))), tee) + split = compose(star(parallel(*map(lift, (first, second)))), tee) classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size)) - join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), splat(zip)) + pairwise_apply = compose(star, parallel) + join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip)) # +>--classify--v - # --extract-->--split--| |--join-->format + # --extract-->--split--| |--join-->reformat # +>--identity--^ self.pipe = rcompose( extract, # ... image-metadata-pairs as a stream split, # ... into an image stream and a metadata stream - splat(parallel(classify, identity)), # ... process streams independently + pairwise_apply(classify, identity), # ... apply functions to the streams pairwise join, # ... the streams reformat, # ... the items ) diff --git a/test/unit_tests/mocked_server_test.py b/test/unit_tests/mocked_server_test.py index 8b8a692..d64c937 100644 --- a/test/unit_tests/mocked_server_test.py +++ b/test/unit_tests/mocked_server_test.py @@ -37,12 +37,12 @@ def test_server_predict_failure(client, mute_logger): def test_server_health_check(client): - response = client.get("/ready") + response = client.get("/health") assert response.status_code == 200 assert response.json == "OK" def test_server_ready_check(client): - response = client.get("/health") + response = client.get("/ready") assert response.status_code == 200 assert response.json == "OK"