diff --git a/image_prediction/flask.py b/image_prediction/flask.py index 8cd5111..97cd999 100644 --- a/image_prediction/flask.py +++ b/image_prediction/flask.py @@ -28,7 +28,7 @@ def make_prediction_server(predict_fn: Callable): resp.status_code = 200 return resp - @app.route("/", methods=["POST"]) + @app.route("/predict", methods=["POST"]) def predict(): def predict_fn_wrapper(pdf, return_dict): return_dict["result"] = predict_fn(pdf) diff --git a/test/integration_tests/server_test.py b/test/integration_tests/server_test.py index 1445a4b..f7ef3ae 100644 --- a/test/integration_tests/server_test.py +++ b/test/integration_tests/server_test.py @@ -1,16 +1,11 @@ -import os import socket from multiprocessing import Process -import coverage import pytest import requests from funcy import retry from image_prediction.flask import make_prediction_server, run_prediction_server -from image_prediction.locations import COVERAGERC - -# os.environ['COVERAGE_PROCESS_START'] = COVERAGERC @pytest.fixture @@ -35,16 +30,8 @@ def url(host, port): @pytest.fixture -def predict_fn(): - def predict(_): - return 42 - - return predict - - -@pytest.fixture -def server(predict_fn): - server = make_prediction_server(predict_fn) +def server(): + server = make_prediction_server(lambda _: 42) return server @@ -63,7 +50,6 @@ def server_ready(url): @pytest.fixture(autouse=True, scope="function") def server_process(server, host_and_port, url): def get_server_process(): - # coverage.process_startup() return Process(target=run_prediction_server, kwargs={"app": server, **host_and_port}) server = get_server_process() @@ -78,7 +64,7 @@ def server_process(server, host_and_port, url): def test_server_predict(url): - response = requests.post(url) + response = requests.post(f"{url}/predict") response.raise_for_status() assert response.json() == 42 diff --git a/test/unit_tests/mocked_server_test.py b/test/unit_tests/mocked_server_test.py new file mode 100644 index 0000000..b6bb083 --- /dev/null +++ b/test/unit_tests/mocked_server_test.py @@ -0,0 +1,32 @@ +import json + +import pytest + +from image_prediction.flask import make_prediction_server + + +@pytest.fixture +def server(): + server = make_prediction_server(lambda _: 42) + server.config.update({"TESTING": True}) + return server + + +@pytest.fixture +def client(server): + return server.test_client() + + +def test_server_predict(client): + response = client.post("/predict") + assert json.loads(response.data) == 42 + + +def test_server_health_check(client): + response = client.get("/health") + assert response.status_code == 200 + + +def test_server_ready_check(client): + response = client.get("/health") + assert response.status_code == 200