refactoring

This commit is contained in:
Matthias Bisping 2022-04-22 17:12:11 +02:00
parent 8abc901004
commit b582726cd1
2 changed files with 10 additions and 9 deletions

View File

@ -1,8 +1,8 @@
import os import os
from functools import partial from functools import partial, reduce
from itertools import chain, tee from itertools import chain, tee
from funcy import rcompose, first, compose, second, chunks, identity from funcy import rcompose, first, compose, second, chunks, identity, curry
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
@ -25,7 +25,7 @@ def parallel(*fs):
return lambda *args: (f(a) for f, a in zip(fs, args)) return lambda *args: (f(a) for f, a in zip(fs, args))
def splat(f): def star(f):
return lambda x: f(*x) return lambda x: f(*x)
@ -35,18 +35,19 @@ class Pipeline:
classifier = get_image_classifier(model_loader, model_identifier) classifier = get_image_classifier(model_loader, model_identifier)
reformat = get_formatter() reformat = get_formatter()
split = compose(splat(parallel(*map(lift, (first, second)))), tee) split = compose(star(parallel(*map(lift, (first, second)))), tee)
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size)) classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), splat(zip)) pairwise_apply = compose(star, parallel)
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip))
# +>--classify--v # +>--classify--v
# --extract-->--split--| |--join-->format # --extract-->--split--| |--join-->reformat
# +>--identity--^ # +>--identity--^
self.pipe = rcompose( self.pipe = rcompose(
extract, # ... image-metadata-pairs as a stream extract, # ... image-metadata-pairs as a stream
split, # ... into an image stream and a metadata stream split, # ... into an image stream and a metadata stream
splat(parallel(classify, identity)), # ... process streams independently pairwise_apply(classify, identity), # ... apply functions to the streams pairwise
join, # ... the streams join, # ... the streams
reformat, # ... the items reformat, # ... the items
) )

View File

@ -37,12 +37,12 @@ def test_server_predict_failure(client, mute_logger):
def test_server_health_check(client): def test_server_health_check(client):
response = client.get("/ready") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
assert response.json == "OK" assert response.json == "OK"
def test_server_ready_check(client): def test_server_ready_check(client):
response = client.get("/health") response = client.get("/ready")
assert response.status_code == 200 assert response.status_code == 200
assert response.json == "OK" assert response.json == "OK"