import io
import logging
import socket
from collections import Counter
from collections.abc import Generator
from multiprocessing import Process
from operator import itemgetter

import fitz
import pytest
import requests
from PIL import Image
from funcy import retry, project, omit
from waitress import serve

from pyinfra.server.dispatcher.dispatcher import Nothing
from pyinfra.server.server import (
    set_up_processing_server,
)
from pyinfra.server.utils import make_streamable_and_wrap_in_packing_logic
from pyinfra.utils.func import starlift
from test.utils.image import image_to_bytes

logger = logging.getLogger(__name__)


@pytest.fixture
def host():
    return "0.0.0.0"


def get_free_port(host):
    # Let the OS pick an unused port, then close the socket so the test server can bind to it.
    with socket.socket() as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]


@pytest.fixture
def port(host):
    return get_free_port(host)


@pytest.fixture
def url(host, port):
    return f"http://{host}:{port}"


@pytest.fixture
def server(server_stream_function, buffer_size, operation_name):
    return set_up_processing_server({operation_name: server_stream_function}, buffer_size)


@pytest.fixture
def operation_name(many_to_n):
    return "multi_inp_op" if many_to_n else "single_inp_op"


@pytest.fixture
def server_stream_function(operation_conditionally_batched, batched):
    return make_streamable_and_wrap_in_packing_logic(operation_conditionally_batched, batched)


@pytest.fixture
def operation_conditionally_batched(operation, batched):
    return starlift(operation) if batched else operation


@pytest.fixture
def operation(core_operation, server_side_test):
    auto_counter = Counter()

    def auto_count(metadata):
        # Client side only: add a running "id" per (dossierId, fileId) pair unless the metadata already has one.
        if not server_side_test:
            idnt = itemgetter("dossierId", "fileId")(metadata)
            auto_counter[idnt] += 1
            return {**metadata, "id": auto_counter[idnt]} if "id" not in metadata else metadata
        else:
            return metadata

    def op(data, metadata):
        assert isinstance(metadata, dict)
        result = core_operation(data, metadata)
        # One-to-many operations return a generator of (data, metadata) pairs; one-to-one operations return one pair.
        if isinstance(result, Generator):
            for data, metadata in result:
                yield data, auto_count(omit(metadata, ["pages", "operation"]))
        else:
            data, metadata = result
            yield data, auto_count(omit(metadata, ["pages", "operation"]))

    if core_operation is Nothing:
        return Nothing

    return op


@pytest.fixture(params=[False])
def server_side_test(request):
    return request.param


@pytest.fixture
def core_operation(item_type, one_to_many, analysis_task):
    def duplicate(string: bytes, metadata):
        # upper() already returns a (data, metadata) pair, so yield it directly.
        for _ in range(2):
            yield upper(string, metadata)

    def upper(string: bytes, metadata):
        return string.decode().upper().encode(), metadata

    def extract(string: bytes, metadata):
        # Yield one character per requested page index, tagging each result with that index as "id".
        for i, c in project(dict(enumerate(string.decode())), metadata["pages"]).items():
            metadata["id"] = i
            yield c.encode(), metadata

    def rotate(im: bytes, metadata):
        im = Image.open(io.BytesIO(im))
        return image_to_bytes(im.rotate(90)), metadata

    def classify(_: bytes, metadata):
        return b"", {"classification": 1, **metadata}

    def stream_pages(pdf: bytes, metadata):
        for i, page in enumerate(fitz.open(stream=pdf)):
            # yield page.get_pixmap().tobytes("png"), metadata
            metadata["id"] = i
            yield f"page_{i}".encode(), metadata

    # Lookup: one_to_many -> item_type -> analysis_task -> operation under test.
    params2op = {
        False: {
            "string": {False: upper},
            "image": {False: rotate, True: classify},
        },
        True: {
            "string": {False: extract},
            "pdf": {False: stream_pages},
        },
    }

    try:
        return params2op[one_to_many][item_type][analysis_task]
    except KeyError:
        msg = f"No operation defined for [{one_to_many=}, {item_type=}, {analysis_task=}]."
        # Log before skipping: pytest.skip raises, so nothing after it would run.
        logger.debug(msg)
        pytest.skip(msg)


@pytest.fixture(params=["pdf", "string", "image"])
def item_type(request):
    return request.param


@pytest.fixture(params=[True, False])
def one_to_many(request):
    return request.param


@pytest.fixture(params=[True, False])
def many_to_n(request):
    return request.param


@pytest.fixture(params=[True, False])
def analysis_task(request):
    return request.param


@pytest.fixture(params=[False, True])
def batched(request):
    """Controls whether the buffer processor function of the web server is applied to batches or to single items."""
    return request.param


@pytest.fixture
def host_and_port(host, port):
    return {"host": host, "port": port}


@retry(tries=5, timeout=1)
def server_ready(url):
    # Retried up to 5 times, 1 s apart, while the server process is still starting up.
    response = requests.get(f"{url}/ready")
    response.raise_for_status()
    return response.status_code == 200


@pytest.fixture(autouse=False, scope="function")
def server_process(server, host_and_port, url):
    # Serve the app with waitress in a separate process for the duration of one test.
    process = Process(target=serve, kwargs={"app": server, **host_and_port})
    process.start()
    try:
        assert server_ready(url), f"Server at {url} did not become ready."
        yield
    finally:
        # Tear down even if the readiness check or the test itself fails.
        process.kill()
        process.join()
        process.close()
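

# A minimal, illustrative sketch (not used by the fixtures above): the
# hypothetical table ``_PARAMS2OP_NAMES`` mirrors ``params2op`` in
# ``core_operation``, with operation names in place of functions, and the
# hypothetical helper ``_resolved_combinations`` walks the same
# ``one_to_many`` x ``item_type`` x ``analysis_task`` grid that the fixtures
# are parametrized over. Any combination missing from the table raises
# KeyError in ``core_operation`` and is skipped via ``pytest.skip``, so only
# the combinations returned here actually run.
import itertools

_PARAMS2OP_NAMES = {
    False: {"string": {False: "upper"}, "image": {False: "rotate", True: "classify"}},
    True: {"string": {False: "extract"}, "pdf": {False: "stream_pages"}},
}


def _resolved_combinations():
    """Map (one_to_many, item_type, analysis_task) to an operation name for every combination that runs."""
    resolved = {}
    for combination in itertools.product([True, False], ["pdf", "string", "image"], [True, False]):
        one_to_many_, item_type_, analysis_task_ = combination
        try:
            resolved[combination] = _PARAMS2OP_NAMES[one_to_many_][item_type_][analysis_task_]
        except KeyError:
            pass  # core_operation would skip this combination
    return resolved


# Usage: of the 2 x 3 x 2 = 12 combinations, 5 resolve to an operation, e.g.
# _resolved_combinations()[(True, "pdf", False)] == "stream_pages".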