refactoring
This commit is contained in:
parent
8abc901004
commit
b582726cd1
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial, reduce
|
||||||
from itertools import chain, tee
|
from itertools import chain, tee
|
||||||
|
|
||||||
from funcy import rcompose, first, compose, second, chunks, identity
|
from funcy import rcompose, first, compose, second, chunks, identity, curry
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
|
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
|
||||||
@ -25,7 +25,7 @@ def parallel(*fs):
|
|||||||
return lambda *args: (f(a) for f, a in zip(fs, args))
|
return lambda *args: (f(a) for f, a in zip(fs, args))
|
||||||
|
|
||||||
|
|
||||||
def splat(f):
|
def star(f):
|
||||||
return lambda x: f(*x)
|
return lambda x: f(*x)
|
||||||
|
|
||||||
|
|
||||||
@ -35,18 +35,19 @@ class Pipeline:
|
|||||||
classifier = get_image_classifier(model_loader, model_identifier)
|
classifier = get_image_classifier(model_loader, model_identifier)
|
||||||
reformat = get_formatter()
|
reformat = get_formatter()
|
||||||
|
|
||||||
split = compose(splat(parallel(*map(lift, (first, second)))), tee)
|
split = compose(star(parallel(*map(lift, (first, second)))), tee)
|
||||||
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
|
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
|
||||||
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), splat(zip))
|
pairwise_apply = compose(star, parallel)
|
||||||
|
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip))
|
||||||
|
|
||||||
# +>--classify--v
|
# +>--classify--v
|
||||||
# --extract-->--split--| |--join-->format
|
# --extract-->--split--| |--join-->reformat
|
||||||
# +>--identity--^
|
# +>--identity--^
|
||||||
|
|
||||||
self.pipe = rcompose(
|
self.pipe = rcompose(
|
||||||
extract, # ... image-metadata-pairs as a stream
|
extract, # ... image-metadata-pairs as a stream
|
||||||
split, # ... into an image stream and a metadata stream
|
split, # ... into an image stream and a metadata stream
|
||||||
splat(parallel(classify, identity)), # ... process streams independently
|
pairwise_apply(classify, identity), # ... apply functions to the streams pairwise
|
||||||
join, # ... the streams
|
join, # ... the streams
|
||||||
reformat, # ... the items
|
reformat, # ... the items
|
||||||
)
|
)
|
||||||
|
|||||||
@ -37,12 +37,12 @@ def test_server_predict_failure(client, mute_logger):
|
|||||||
|
|
||||||
|
|
||||||
def test_server_health_check(client):
|
def test_server_health_check(client):
|
||||||
response = client.get("/ready")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json == "OK"
|
assert response.json == "OK"
|
||||||
|
|
||||||
|
|
||||||
def test_server_ready_check(client):
|
def test_server_ready_check(client):
|
||||||
response = client.get("/health")
|
response = client.get("/ready")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json == "OK"
|
assert response.json == "OK"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user