refactoring

This commit is contained in:
Matthias Bisping 2022-04-22 17:12:11 +02:00
parent 8abc901004
commit b582726cd1
2 changed files with 10 additions and 9 deletions

View File

@ -1,8 +1,8 @@
import os
from functools import partial
from functools import partial, reduce
from itertools import chain, tee
from funcy import rcompose, first, compose, second, chunks, identity
from funcy import rcompose, first, compose, second, chunks, identity, curry
from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
@ -25,7 +25,7 @@ def parallel(*fs):
return lambda *args: (f(a) for f, a in zip(fs, args))
def splat(f):
def star(f):
return lambda x: f(*x)
@ -35,18 +35,19 @@ class Pipeline:
classifier = get_image_classifier(model_loader, model_identifier)
reformat = get_formatter()
split = compose(splat(parallel(*map(lift, (first, second)))), tee)
split = compose(star(parallel(*map(lift, (first, second)))), tee)
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), splat(zip))
pairwise_apply = compose(star, parallel)
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip))
# +>--classify--v
# --extract-->--split--| |--join-->format
# --extract-->--split--| |--join-->reformat
# +>--identity--^
self.pipe = rcompose(
extract, # ... image-metadata-pairs as a stream
split, # ... into an image stream and a metadata stream
splat(parallel(classify, identity)), # ... process streams independently
pairwise_apply(classify, identity), # ... apply functions to the streams pairwise
join, # ... the streams
reformat, # ... the items
)

View File

@ -37,12 +37,12 @@ def test_server_predict_failure(client, mute_logger):
def test_server_health_check(client):
response = client.get("/ready")
response = client.get("/health")
assert response.status_code == 200
assert response.json == "OK"
def test_server_ready_check(client):
response = client.get("/health")
response = client.get("/ready")
assert response.status_code == 200
assert response.json == "OK"