diff --git a/pyinfra/callback.py b/pyinfra/callback.py new file mode 100644 index 0000000..0b3d817 --- /dev/null +++ b/pyinfra/callback.py @@ -0,0 +1,76 @@ +import logging + +from funcy import merge, omit, lmap + +from pyinfra.exceptions import AnalysisFailure + +logger = logging.getLogger(__name__) + + +class Callback: + """This is the callback that is applied to items pulled from the storage. It forwards these items to an analysis + endpoint. + """ + + def __init__(self, base_url, pipeline_factory): + self.base_url = base_url + self.pipeline_factory = pipeline_factory + self.endpoint2pipeline = {} + + def __make_endpoint(self, operation): + return f"{self.base_url}/{operation}" + + def __get_pipeline(self, endpoint): + if endpoint in self.endpoint2pipeline: + pipeline = self.endpoint2pipeline[endpoint] + + else: + pipeline = self.pipeline_factory(endpoint) + self.endpoint2pipeline[endpoint] = pipeline + + return pipeline + + @staticmethod + def __run_pipeline(pipeline, body): + """ + TODO: Since data and metadata are passed as singletons, there is no buffering and hence no batching happening + within the pipeline. However, the queue acknowledgment logic needs to be changed in order to facilitate + passing non-singletons, to only ack a message, once a response is pulled from the output queue of the + pipeline. Probably the pipeline return value needs to contain the queue message frame (or so), in order for + the queue manager to tell which message to ack. + + TODO: casting list (lmap) on `analysis_response_stream` is a temporary solution, while the client pipeline + operates on singletons ([data], [metadata]). 
+ """ + + def combine_storage_item_metadata_with_queue_message_metadata(body): + return merge(body["metadata"], omit(body, ["data", "metadata"])) + + def remove_queue_message_metadata(result): + metadata = omit(result["metadata"], queue_message_keys(body)) + return {**result, "metadata": metadata} + + def queue_message_keys(body): + return {*body.keys()}.difference({"data", "metadata"}) + + try: + data = body["data"] + metadata = combine_storage_item_metadata_with_queue_message_metadata(body) + analysis_response_stream = pipeline([data], [metadata]) + analysis_response_stream = lmap(remove_queue_message_metadata, analysis_response_stream) + return analysis_response_stream + + except Exception as err: + logger.error(err) + raise AnalysisFailure from err + + def __call__(self, body: dict): + operation = body.get("operation", "submit") + endpoint = self.__make_endpoint(operation) + pipeline = self.__get_pipeline(endpoint) + + try: + logging.debug(f"Requesting analysis from {endpoint}...") + return self.__run_pipeline(pipeline, body) + except AnalysisFailure: + logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.") diff --git a/pyinfra/component_factory.py b/pyinfra/component_factory.py new file mode 100644 index 0000000..17733e0 --- /dev/null +++ b/pyinfra/component_factory.py @@ -0,0 +1,109 @@ +import logging +from functools import lru_cache + +from funcy import project, identity, rcompose + +from pyinfra.callback import Callback +from pyinfra.config import parse_disjunction_string +from pyinfra.file_descriptor_manager import FileDescriptorManager +from pyinfra.queue.consumer import Consumer +from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager +from pyinfra.server.client_pipeline import ClientPipeline +from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher +from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer +from pyinfra.server.packer.packers.rest import RestPacker +from 
pyinfra.server.receiver.receivers.rest import RestReceiver +from pyinfra.storage import storages +from pyinfra.visitor import QueueVisitor +from pyinfra.visitor.downloader import Downloader +from pyinfra.visitor.response_formatter.formatters.default import DefaultResponseFormatter +from pyinfra.visitor.response_formatter.formatters.identity import IdentityResponseFormatter +from pyinfra.visitor.strategies.response.aggregation import AggregationStorageStrategy, ProjectingUploadFormatter + +logger = logging.getLogger(__name__) + + +class ComponentFactory: + def __init__(self, config): + self.config = config + + @lru_cache(maxsize=None) + def get_consumer(self, callback=None): + callback = callback or self.get_callback() + return Consumer(self.get_visitor(callback), self.get_queue_manager()) + + @lru_cache(maxsize=None) + def get_callback(self, analysis_base_url=None): + analysis_base_url = analysis_base_url or self.config.rabbitmq.callback.analysis_endpoint + + callback = Callback(analysis_base_url, self.get_pipeline) + + def wrapped(body): + body_repr = project(body, ["dossierId", "fileId", "operation"]) + logger.info(f"Processing {body_repr}...") + result = callback(body) + logger.info(f"Completed processing {body_repr}...") + return result + + return wrapped + + @lru_cache(maxsize=None) + def get_visitor(self, callback): + return QueueVisitor( + callback=callback, + data_loader=self.get_downloader(), + response_strategy=self.get_response_strategy(), + response_formatter=self.get_response_formatter(), + ) + + @lru_cache(maxsize=None) + def get_queue_manager(self): + return PikaQueueManager(self.config.rabbitmq.queues.input, self.config.rabbitmq.queues.output) + + @lru_cache(maxsize=None) + def get_storage(self): + return storages.get_storage(self.config.storage.backend) + + @lru_cache(maxsize=None) + def get_response_strategy(self, storage=None): + return AggregationStorageStrategy( + storage=storage or self.get_storage(), + 
file_descriptor_manager=self.get_file_descriptor_manager(), + upload_formatter=self.get_upload_formatter(), + ) + + @lru_cache(maxsize=None) + def get_file_descriptor_manager(self): + return FileDescriptorManager( + bucket_name=parse_disjunction_string(self.config.storage.bucket), + operation2file_patterns=self.get_operation2file_patterns(), + ) + + @lru_cache(maxsize=None) + def get_upload_formatter(self): + return {"identity": identity, "projecting": ProjectingUploadFormatter()}[self.config.service.upload_formatter] + + @lru_cache(maxsize=None) + def get_response_formatter(self): + return {"default": DefaultResponseFormatter(), "identity": IdentityResponseFormatter()}[ + self.config.service.response_formatter + ] + + @lru_cache(maxsize=None) + def get_operation2file_patterns(self): + return self.config.service.operations + + @lru_cache(maxsize=None) + def get_downloader(self, storage=None): + return Downloader( + storage=storage or self.get_storage(), + bucket_name=parse_disjunction_string(self.config.storage.bucket), + file_descriptor_manager=self.get_file_descriptor_manager(), + ) + + @staticmethod + @lru_cache(maxsize=None) + def get_pipeline(endpoint): + return ClientPipeline( + RestPacker(), RestDispatcher(endpoint), RestReceiver(), rcompose(RestPickupStreamer(), RestReceiver()) + ) diff --git a/pyinfra/default_objects.py b/pyinfra/default_objects.py index c1d671e..2dec7c6 100644 --- a/pyinfra/default_objects.py +++ b/pyinfra/default_objects.py @@ -1,179 +1,8 @@ -import logging from functools import lru_cache -from funcy import rcompose, omit, merge, lmap, project, identity - -from pyinfra.config import parse_disjunction_string -from pyinfra.exceptions import AnalysisFailure -from pyinfra.file_descriptor_manager import FileDescriptorManager -from pyinfra.queue.consumer import Consumer -from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager -from pyinfra.server.client_pipeline import ClientPipeline -from 
pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher -from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer -from pyinfra.server.packer.packers.rest import RestPacker -from pyinfra.server.receiver.receivers.rest import RestReceiver -from pyinfra.storage import storages -from pyinfra.visitor import QueueVisitor -from pyinfra.visitor.downloader import Downloader -from pyinfra.visitor.response_formatter.formatters.default import DefaultResponseFormatter -from pyinfra.visitor.response_formatter.formatters.identity import IdentityResponseFormatter -from pyinfra.visitor.strategies.response.aggregation import ( - AggregationStorageStrategy, - ProjectingUploadFormatter, -) - -logger = logging.getLogger(__name__) +from pyinfra.component_factory import ComponentFactory @lru_cache(maxsize=None) def get_component_factory(config): return ComponentFactory(config) - - -class ComponentFactory: - def __init__(self, config): - self.config = config - - @lru_cache(maxsize=None) - def get_consumer(self, callback=None): - callback = callback or self.get_callback() - return Consumer(self.get_visitor(callback), self.get_queue_manager()) - - @lru_cache(maxsize=None) - def get_callback(self, analysis_base_url=None): - analysis_base_url = analysis_base_url or self.config.rabbitmq.callback.analysis_endpoint - - callback = Callback(analysis_base_url) - - def wrapped(body): - body_repr = project(body, ["dossierId", "fileId", "pages", "images", "operation"]) - logger.info(f"Processing {body_repr}...") - return callback(body) - - return wrapped - - @lru_cache(maxsize=None) - def get_visitor(self, callback): - return QueueVisitor( - callback=callback, - data_loader=self.get_downloader(), - response_strategy=self.get_response_strategy(), - response_formatter=self.get_response_formatter(), - ) - - @lru_cache(maxsize=None) - def get_queue_manager(self): - return PikaQueueManager(self.config.rabbitmq.queues.input, self.config.rabbitmq.queues.output) - - 
@lru_cache(maxsize=None) - def get_storage(self): - return storages.get_storage(self.config.storage.backend) - - @lru_cache(maxsize=None) - def get_response_strategy(self, storage=None): - return AggregationStorageStrategy( - storage=storage or self.get_storage(), - file_descriptor_manager=self.get_file_descriptor_manager(), - upload_formatter=self.get_upload_formatter(), - ) - - @lru_cache(maxsize=None) - def get_file_descriptor_manager(self): - return FileDescriptorManager( - bucket_name=parse_disjunction_string(self.config.storage.bucket), - operation2file_patterns=self.get_operation2file_patterns(), - ) - - @lru_cache(maxsize=None) - def get_upload_formatter(self): - return {"identity": identity, "projecting": ProjectingUploadFormatter()}[self.config.service.upload_formatter] - - @lru_cache(maxsize=None) - def get_response_formatter(self): - return {"default": DefaultResponseFormatter(), "identity": IdentityResponseFormatter()}[ - self.config.service.response_formatter - ] - - @lru_cache(maxsize=None) - def get_operation2file_patterns(self): - return self.config.service.operations - - @lru_cache(maxsize=None) - def get_downloader(self, storage=None): - return Downloader( - storage=storage or self.get_storage(), - bucket_name=parse_disjunction_string(self.config.storage.bucket), - file_descriptor_manager=self.get_file_descriptor_manager(), - ) - - -class Callback: - def __init__(self, base_url): - self.base_url = base_url - self.endpoint2pipeline = {} - - def __make_endpoint(self, operation): - return f"{self.base_url}/{operation}" - - def __get_pipeline(self, endpoint): - if endpoint in self.endpoint2pipeline: - pipeline = self.endpoint2pipeline[endpoint] - - else: - pipeline = get_pipeline(endpoint) - self.endpoint2pipeline[endpoint] = pipeline - - return pipeline - - @staticmethod - def __run_pipeline(pipeline, body): - """ - TODO: since data and metadata are passed as singletons, there is no buffering and hence no batching happening - within the pipeline. 
However, the queue acknowledgment logic needs to be changed in order to facilitate - passing non-singletons, to only ack a message, once a response is pulled from the output queue of the - pipeline. Probably the pipeline return value needs to contains the queue message frame (or so), in order for - the queue manager to tell which message to ack. - - TODO: casting list (lmap) on `analysis_response_stream` is a temporary solution, while the client pipeline - operates on singletons ([data], [metadata]). - """ - - def combine_storage_item_metadata_with_queue_message_metadata(body): - return merge(body["metadata"], omit(body, ["data", "metadata"])) - - def remove_queue_message_metadata(result): - metadata = omit(result["metadata"], queue_message_keys(body)) - return {**result, "metadata": metadata} - - def queue_message_keys(body): - return {*body.keys()}.difference({"data", "metadata"}) - - try: - data = body["data"] - metadata = combine_storage_item_metadata_with_queue_message_metadata(body) - analysis_response_stream = pipeline([data], [metadata]) - analysis_response_stream = lmap(remove_queue_message_metadata, analysis_response_stream) - return analysis_response_stream - - except Exception as err: - logger.error(err) - raise AnalysisFailure from err - - def __call__(self, body: dict): - operation = body.get("operation", "submit") - endpoint = self.__make_endpoint(operation) - pipeline = self.__get_pipeline(endpoint) - - try: - logging.debug(f"Requesting analysis from {endpoint}...") - return self.__run_pipeline(pipeline, body) - except AnalysisFailure: - logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.") - - -@lru_cache(maxsize=None) -def get_pipeline(endpoint): - return ClientPipeline( - RestPacker(), RestDispatcher(endpoint), RestReceiver(), rcompose(RestPickupStreamer(), RestReceiver()) - ) diff --git a/src/serve.py b/src/serve.py index 9608c73..6db2ef6 100644 --- a/src/serve.py +++ b/src/serve.py @@ -4,7 +4,7 @@ from multiprocessing 
import Process from pyinfra.utils.retry import retry from pyinfra.config import CONFIG -from pyinfra.default_objects import ComponentFactory +from pyinfra.component_factory import ComponentFactory from pyinfra.exceptions import ConsumerError from pyinfra.flask import run_probing_webserver, set_up_probing_webserver from pyinfra.utils.banner import show_banner diff --git a/test/config.yaml b/test/config.yaml index ded376f..894a84f 100644 --- a/test/config.yaml +++ b/test/config.yaml @@ -48,5 +48,5 @@ webserver: mock_analysis_endpoint: "http://127.0.0.1:5000" -use_docker_fixture: 0 +use_docker_fixture: 1 logging: 0 \ No newline at end of file diff --git a/test/unit_tests/consumer_test.py b/test/unit_tests/consumer_test.py index 64f2a37..4975b89 100644 --- a/test/unit_tests/consumer_test.py +++ b/test/unit_tests/consumer_test.py @@ -4,7 +4,7 @@ import pytest from funcy import pluck, lflatten from pyinfra.config import CONFIG -from pyinfra.default_objects import ComponentFactory +from pyinfra.component_factory import ComponentFactory from pyinfra.exceptions import ProcessingFailure from pyinfra.visitor.strategies.response.forwarding import ForwardingStrategy from test.utils.storage import pack_for_upload