import socket
from multiprocessing import Process

import pytest
import requests
from funcy import retry
from waitress import serve

from image_prediction.flask import make_prediction_server
from image_prediction.pipeline import load_pipeline


@pytest.fixture
def host():
    """Address the test server binds to and the client connects to."""
    return "0.0.0.0"


def get_free_port(host):
    """Return a port on *host* that the OS reported as free.

    Binding to port 0 lets the OS choose an unused port. The probing socket is
    closed before returning so the server can bind the same port. Another
    process could still take the port in between; that race is accepted here.
    """
    with socket.socket() as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]


@pytest.fixture
def port(host):
    """A free port on ``host`` for the server under test."""
    return get_free_port(host)


@pytest.fixture
def url(host, port):
    """Base URL of the server under test."""
    return f"http://{host}:{port}"


@pytest.fixture(params=["dummy", "actual"])
def server_type(request):
    """Which prediction server to run; tests narrow this via ``parametrize``."""
    return request.param


@pytest.fixture
def server(server_type):
    """Build the prediction app for ``server_type``.

    ``"dummy"`` halves the integer sent in the request body. ``"actual"``
    runs the real pipeline; note that the lambda calls ``load_pipeline`` on
    every prediction request.

    Raises:
        ValueError: if ``server_type`` is not ``"dummy"`` or ``"actual"``.
    """
    if server_type == "dummy":
        return make_prediction_server(lambda x: int(x.decode()) // 2)
    elif server_type == "actual":
        return make_prediction_server(lambda x: list(load_pipeline(verbose=False)(x)))
    else:
        raise ValueError(f"Unknown server type {server_type}.")


@pytest.fixture
def host_and_port(host, port):
    """``host``/``port`` keyword arguments for ``waitress.serve``."""
    return {"host": host, "port": port}


@retry(tries=5, timeout=1)
def server_ready(url):
    """Poll ``{url}/ready``. Return True when it answers with HTTP 200.

    ``funcy.retry`` retries up to 5 times, 1 second apart, on any exception
    (e.g. connection refused while the server is still starting). The
    per-request timeout keeps a server that accepts but never answers from
    blocking a single attempt forever.
    """
    response = requests.get(f"{url}/ready", timeout=5)
    response.raise_for_status()
    return response.status_code == 200


@pytest.fixture(autouse=True, scope="function")
def server_process(server, host_and_port, url):
    """Serve ``server`` with waitress in a child process for each test.

    The fixture waits until the server reports ready, then yields to the
    test. The child process is always killed and reaped afterwards, even when
    the readiness check fails. Without that, a stray non-daemonic child would
    remain, and the interpreter waits on such children at exit.

    Raises:
        RuntimeError: if ``/ready`` answers successfully but not with 200.
    """
    process = Process(target=serve, kwargs={"app": server, **host_and_port})
    process.start()
    try:
        if not server_ready(url):
            raise RuntimeError(f"Server at {url} did not report ready.")
        yield
    finally:
        process.kill()
        process.join()
        process.close()


@pytest.mark.parametrize("server_type", ["actual"])
def test_server_predict(url, real_pdf, real_expected_service_response):
    """The real pipeline returns the expected response for a real PDF."""
    response = requests.post(f"{url}/predict", data=real_pdf)
    response.raise_for_status()
    assert response.json() == real_expected_service_response


@pytest.mark.parametrize("server_type", ["dummy"])
def test_server_dummy_operation(url):
    """The dummy server halves the integer sent in the request body."""
    response = requests.post(f"{url}/predict", data=b"42")
    response.raise_for_status()
    assert response.json() == 21


@pytest.mark.parametrize("server_type", ["dummy"])
def test_server_health_check(url):
    """``/health`` answers with HTTP 200."""
    response = requests.get(f"{url}/health")
    response.raise_for_status()
    assert response.status_code == 200


@pytest.mark.parametrize("server_type", ["dummy"])
def test_server_ready_check(url):
    """``/ready`` reports the running server as ready."""
    assert server_ready(url)