diff --git a/pyinfra/server/receiver/__init__.py b/pyinfra/server/receiver/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyinfra/server/receiver/receiver.py b/pyinfra/server/receiver/receiver.py new file mode 100644 index 0000000..27b6d06 --- /dev/null +++ b/pyinfra/server/receiver/receiver.py @@ -0,0 +1,9 @@ +import abc +from typing import Iterable + + +class Receiver(abc.ABC): + + @abc.abstractmethod + def __call__(self, package: Iterable): + pass diff --git a/pyinfra/server/receiver/receivers/__init__.py b/pyinfra/server/receiver/receivers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyinfra/server/rest.py b/pyinfra/server/rest.py index 973410c..be45e17 100644 --- a/pyinfra/server/rest.py +++ b/pyinfra/server/rest.py @@ -5,9 +5,8 @@ from typing import Iterable from funcy import identity, rcompose, flatten from pyinfra.server.packer.packers.rest import RestPacker -from pyinfra.server.sender.senders.rest import RestSender - -from pyinfra.server.utils import stream_response_payloads, extract_payload_from_responses +from pyinfra.server.sender.sender import RestServer, Sender +from pyinfra.server.utils import stream_response_payloads from pyinfra.utils.func import lift logger = logging.getLogger("PIL.PngImagePlugin") @@ -55,8 +54,7 @@ def head(endpoint): """Sends packages of data and metadata to endpoint and returns response.""" return rcompose( RestPacker(), - RestSender(endpoint), - extract_payload_from_responses, + Sender(RestServer(endpoint)), )(data, metadata) return send diff --git a/pyinfra/server/sender/sender.py b/pyinfra/server/sender/sender.py index 5c7b023..8fba62d 100644 --- a/pyinfra/server/sender/sender.py +++ b/pyinfra/server/sender/sender.py @@ -1,9 +1,63 @@ import abc from typing import Iterable +import flask +import requests +from more_itertools import peekable -class Sender(abc.ABC): + +class Nothing: + pass + + +def has_next(peekable_iter): + return peekable_iter.peek(Nothing) != Nothing + + +class Server(abc.ABC): + @abc.abstractmethod + def patch(self, package): + pass @abc.abstractmethod - def __call__(self, package: Iterable): + def post(self, package): pass + + +class RestServer(Server): + def __init__(self, endpoint): + self.endpoint = endpoint + + def patch(self, package): + return requests.patch(self.endpoint, json=package) + + def post(self, package): + return requests.post(self.endpoint, json=package) + + +class Sender: + def __init__(self, server: Server): + self.server = server + + def __call__(self, packages: Iterable[dict]): + + packages = peekable(packages) + + for package in packages: + method = self.server.patch if has_next(packages) else self.server.post + response = method(package) + yield response + + +class Receiver(abc.ABC): + @abc.abstractmethod + def __call__(self, responses: Iterable): + pass + + +class RestReceiver(abc.ABC): + @abc.abstractmethod + def __call__(self, responses: Iterable[requests.Response]): + for response in responses: + response.raise_for_status() + yield response.json() diff --git a/pyinfra/server/sender/senders/rest.py b/pyinfra/server/sender/senders/rest.py index 53ecbd1..ed14e32 100644 --- a/pyinfra/server/sender/senders/rest.py +++ b/pyinfra/server/sender/senders/rest.py @@ -26,4 +26,4 @@ class RestSender(Sender): method = requests.patch if has_next(packages) else requests.post response = method(self.endpoint, json=package) response.raise_for_status() - yield response + yield response.json() diff --git a/pyinfra/server/utils.py b/pyinfra/server/utils.py index ebd9606..68259f3 100644 --- a/pyinfra/server/utils.py +++ b/pyinfra/server/utils.py @@ -34,8 +34,8 @@ def dispatch_http_method_left_and_forward_data_right(*args): return parallel_map(dispatch_methods, lift(identity))(*args) -def extract_payload_from_responses(payloads): - return map(methodcaller("json"), payloads) +# def extract_payload_from_responses(payloads): +# return map(methodcaller("json"), payloads) def pack(data: bytes, metadata: dict): diff --git a/test/unit_tests/rest/receiver_test.py b/test/unit_tests/rest/receiver_test.py new file mode 100644 index 0000000..12b8c72 --- /dev/null +++ b/test/unit_tests/rest/receiver_test.py @@ -0,0 +1,10 @@ +# import pytest +# +# from pyinfra.server.sender.sender import RestServer, Sender +# +# +# @pytest.mark.parametrize("batched", [True, False]) +# @pytest.mark.parametrize("item_type", ["string", "image", "pdf"]) +# def test_rest_receiver(url, packages, server_process): +# sender = Sender(RestServer(f"{url}/process")) +# assert all([r.status_code == 200 for r in receiver(sender(packages))]) diff --git a/test/unit_tests/rest/sender_test.py b/test/unit_tests/rest/sender_test.py index c6a6d07..dc3a56c 100644 --- a/test/unit_tests/rest/sender_test.py +++ b/test/unit_tests/rest/sender_test.py @@ -1,10 +1,10 @@ import pytest -from pyinfra.server.sender.senders.rest import RestSender +from pyinfra.server.sender.sender import RestServer, Sender @pytest.mark.parametrize("batched", [True, False]) @pytest.mark.parametrize("item_type", ["string", "image", "pdf"]) -def test_rest_packer(url, packages, server_process): - sender = RestSender(f"{url}/process") +def test_rest_sender(url, packages, server_process): + sender = Sender(RestServer(f"{url}/process")) assert all([r.status_code == 200 for r in sender(packages)])