Refactor component factory, callback, and client-pipeline getter

This commit is contained in:
Matthias Bisping 2022-06-24 11:06:10 +02:00
parent 6c024e1a78
commit 80f04e5449
6 changed files with 189 additions and 175 deletions

76
pyinfra/callback.py Normal file
View File

@ -0,0 +1,76 @@
import logging
from funcy import merge, omit, lmap
from pyinfra.exceptions import AnalysisFailure
logger = logging.getLogger(__name__)
class Callback:
    """This is the callback that is applied to items pulled from the storage.

    It forwards these items to an analysis endpoint, lazily creating one
    client pipeline per endpoint (via ``pipeline_factory``) and caching it
    for reuse.
    """

    def __init__(self, base_url, pipeline_factory):
        # pipeline_factory: callable mapping an endpoint URL to a pipeline
        # (e.g. ComponentFactory.get_pipeline).
        self.base_url = base_url
        self.pipeline_factory = pipeline_factory
        self.endpoint2pipeline = {}

    def __make_endpoint(self, operation):
        """Build the full endpoint URL for the given operation."""
        return f"{self.base_url}/{operation}"

    def __get_pipeline(self, endpoint):
        """Return the cached pipeline for `endpoint`, creating it on first use."""
        if endpoint not in self.endpoint2pipeline:
            self.endpoint2pipeline[endpoint] = self.pipeline_factory(endpoint)
        return self.endpoint2pipeline[endpoint]

    @staticmethod
    def __run_pipeline(pipeline, body):
        """Run `pipeline` on a single queue message `body`.

        Queue-level metadata (every key of `body` except "data"/"metadata") is
        merged into the storage item's metadata before the pipeline call and
        stripped from each result afterwards, so downstream consumers only see
        storage-item metadata.

        Raises:
            AnalysisFailure: wrapping any exception raised by the pipeline.

        TODO: Since data and metadata are passed as singletons, there is no
        buffering and hence no batching happening within the pipeline. However,
        the queue acknowledgment logic needs to be changed in order to
        facilitate passing non-singletons, to only ack a message once a
        response is pulled from the output queue of the pipeline. Probably the
        pipeline return value needs to contain the queue message frame (or so),
        in order for the queue manager to tell which message to ack.
        TODO: casting list (lmap) on `analysis_response_stream` is a temporary
        solution, while the client pipeline operates on singletons
        ([data], [metadata]).
        """

        def combine_storage_item_metadata_with_queue_message_metadata(body):
            # Fold queue-level keys into the item's metadata for the pipeline.
            return merge(body["metadata"], omit(body, ["data", "metadata"]))

        def remove_queue_message_metadata(result):
            metadata = omit(result["metadata"], queue_message_keys(body))
            return {**result, "metadata": metadata}

        def queue_message_keys(body):
            return {*body.keys()}.difference({"data", "metadata"})

        try:
            data = body["data"]
            metadata = combine_storage_item_metadata_with_queue_message_metadata(body)
            analysis_response_stream = pipeline([data], [metadata])
            analysis_response_stream = lmap(remove_queue_message_metadata, analysis_response_stream)
            return analysis_response_stream
        except Exception as err:
            logger.error(err)
            raise AnalysisFailure from err

    def __call__(self, body: dict):
        """Process a single queue message.

        Returns the analysis response stream, or None when the analysis
        failed (best-effort: the failure is logged, not re-raised).
        """
        operation = body.get("operation", "submit")
        endpoint = self.__make_endpoint(operation)
        pipeline = self.__get_pipeline(endpoint)
        try:
            # Use the module logger (not the root logger via `logging.*`) so
            # messages respect this package's logging configuration.
            logger.debug(f"Requesting analysis from {endpoint}...")
            return self.__run_pipeline(pipeline, body)
        except AnalysisFailure:
            logger.warning(f"Exception caught when calling analysis endpoint {endpoint}.")

View File

@ -0,0 +1,109 @@
import logging
from functools import lru_cache
from funcy import project, identity, rcompose
from pyinfra.callback import Callback
from pyinfra.config import parse_disjunction_string
from pyinfra.file_descriptor_manager import FileDescriptorManager
from pyinfra.queue.consumer import Consumer
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
from pyinfra.server.client_pipeline import ClientPipeline
from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher
from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer
from pyinfra.server.packer.packers.rest import RestPacker
from pyinfra.server.receiver.receivers.rest import RestReceiver
from pyinfra.storage import storages
from pyinfra.visitor import QueueVisitor
from pyinfra.visitor.downloader import Downloader
from pyinfra.visitor.response_formatter.formatters.default import DefaultResponseFormatter
from pyinfra.visitor.response_formatter.formatters.identity import IdentityResponseFormatter
from pyinfra.visitor.strategies.response.aggregation import AggregationStorageStrategy, ProjectingUploadFormatter
logger = logging.getLogger(__name__)
class ComponentFactory:
    """Builds and caches the application's wired-together components.

    Every getter is memoized with ``lru_cache`` so repeated calls return the
    same instance, giving singleton-like behaviour per factory.
    """

    def __init__(self, config):
        self.config = config

    @lru_cache(maxsize=None)
    def get_consumer(self, callback=None):
        """Consumer wiring the visitor to the queue manager."""
        if callback is None:
            callback = self.get_callback()
        visitor = self.get_visitor(callback)
        return Consumer(visitor, self.get_queue_manager())

    @lru_cache(maxsize=None)
    def get_callback(self, analysis_base_url=None):
        """Analysis callback wrapped with per-message progress logging."""
        if analysis_base_url is None:
            analysis_base_url = self.config.rabbitmq.callback.analysis_endpoint
        callback = Callback(analysis_base_url, self.get_pipeline)

        def wrapped(body):
            body_repr = project(body, ["dossierId", "fileId", "operation"])
            logger.info(f"Processing {body_repr}...")
            result = callback(body)
            logger.info(f"Completed processing {body_repr}...")
            return result

        return wrapped

    @lru_cache(maxsize=None)
    def get_visitor(self, callback):
        """Queue visitor combining callback, downloader, and response handling."""
        downloader = self.get_downloader()
        strategy = self.get_response_strategy()
        formatter = self.get_response_formatter()
        return QueueVisitor(
            callback=callback,
            data_loader=downloader,
            response_strategy=strategy,
            response_formatter=formatter,
        )

    @lru_cache(maxsize=None)
    def get_queue_manager(self):
        """Pika-based queue manager over the configured input/output queues."""
        queues = self.config.rabbitmq.queues
        return PikaQueueManager(queues.input, queues.output)

    @lru_cache(maxsize=None)
    def get_storage(self):
        """Storage backend selected by configuration."""
        backend = self.config.storage.backend
        return storages.get_storage(backend)

    @lru_cache(maxsize=None)
    def get_response_strategy(self, storage=None):
        """Strategy aggregating analysis responses into storage."""
        if storage is None:
            storage = self.get_storage()
        return AggregationStorageStrategy(
            storage=storage,
            file_descriptor_manager=self.get_file_descriptor_manager(),
            upload_formatter=self.get_upload_formatter(),
        )

    @lru_cache(maxsize=None)
    def get_file_descriptor_manager(self):
        """Manager resolving file descriptors per operation pattern."""
        bucket = parse_disjunction_string(self.config.storage.bucket)
        patterns = self.get_operation2file_patterns()
        return FileDescriptorManager(
            bucket_name=bucket,
            operation2file_patterns=patterns,
        )

    @lru_cache(maxsize=None)
    def get_upload_formatter(self):
        """Upload formatter chosen by configuration ('identity' or 'projecting')."""
        formatters = {"identity": identity, "projecting": ProjectingUploadFormatter()}
        return formatters[self.config.service.upload_formatter]

    @lru_cache(maxsize=None)
    def get_response_formatter(self):
        """Response formatter chosen by configuration ('default' or 'identity')."""
        formatters = {
            "default": DefaultResponseFormatter(),
            "identity": IdentityResponseFormatter(),
        }
        return formatters[self.config.service.response_formatter]

    @lru_cache(maxsize=None)
    def get_operation2file_patterns(self):
        """Mapping of operation name to file patterns, straight from config."""
        return self.config.service.operations

    @lru_cache(maxsize=None)
    def get_downloader(self, storage=None):
        """Downloader pulling items from the configured storage bucket."""
        if storage is None:
            storage = self.get_storage()
        bucket = parse_disjunction_string(self.config.storage.bucket)
        return Downloader(
            storage=storage,
            bucket_name=bucket,
            file_descriptor_manager=self.get_file_descriptor_manager(),
        )

    @staticmethod
    @lru_cache(maxsize=None)
    def get_pipeline(endpoint):
        """Client pipeline (cached, one per endpoint) for REST analysis calls."""
        packer = RestPacker()
        dispatcher = RestDispatcher(endpoint)
        receiver = RestReceiver()
        interpreter = rcompose(RestPickupStreamer(), RestReceiver())
        return ClientPipeline(packer, dispatcher, receiver, interpreter)

View File

@ -1,179 +1,8 @@
import logging
from functools import lru_cache
from funcy import rcompose, omit, merge, lmap, project, identity
from pyinfra.config import parse_disjunction_string
from pyinfra.exceptions import AnalysisFailure
from pyinfra.file_descriptor_manager import FileDescriptorManager
from pyinfra.queue.consumer import Consumer
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
from pyinfra.server.client_pipeline import ClientPipeline
from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher
from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer
from pyinfra.server.packer.packers.rest import RestPacker
from pyinfra.server.receiver.receivers.rest import RestReceiver
from pyinfra.storage import storages
from pyinfra.visitor import QueueVisitor
from pyinfra.visitor.downloader import Downloader
from pyinfra.visitor.response_formatter.formatters.default import DefaultResponseFormatter
from pyinfra.visitor.response_formatter.formatters.identity import IdentityResponseFormatter
from pyinfra.visitor.strategies.response.aggregation import (
AggregationStorageStrategy,
ProjectingUploadFormatter,
)
logger = logging.getLogger(__name__)
from pyinfra.component_factory import ComponentFactory
@lru_cache(maxsize=None)
def get_component_factory(config):
    """Return the ComponentFactory for `config`, cached so repeated calls
    with the same config share one factory (and hence its cached components).
    """
    factory = ComponentFactory(config)
    return factory
class ComponentFactory:
    """Builds and caches the application's wired-together components.

    Every getter is memoized with ``lru_cache`` so repeated calls return the
    same instance, giving singleton-like behaviour per factory.
    """

    def __init__(self, config):
        self.config = config

    @lru_cache(maxsize=None)
    def get_consumer(self, callback=None):
        """Consumer wiring the visitor to the queue manager."""
        if callback is None:
            callback = self.get_callback()
        visitor = self.get_visitor(callback)
        return Consumer(visitor, self.get_queue_manager())

    @lru_cache(maxsize=None)
    def get_callback(self, analysis_base_url=None):
        """Analysis callback wrapped with per-message progress logging."""
        if analysis_base_url is None:
            analysis_base_url = self.config.rabbitmq.callback.analysis_endpoint
        callback = Callback(analysis_base_url)

        def wrapped(body):
            body_repr = project(body, ["dossierId", "fileId", "pages", "images", "operation"])
            logger.info(f"Processing {body_repr}...")
            return callback(body)

        return wrapped

    @lru_cache(maxsize=None)
    def get_visitor(self, callback):
        """Queue visitor combining callback, downloader, and response handling."""
        downloader = self.get_downloader()
        strategy = self.get_response_strategy()
        formatter = self.get_response_formatter()
        return QueueVisitor(
            callback=callback,
            data_loader=downloader,
            response_strategy=strategy,
            response_formatter=formatter,
        )

    @lru_cache(maxsize=None)
    def get_queue_manager(self):
        """Pika-based queue manager over the configured input/output queues."""
        queues = self.config.rabbitmq.queues
        return PikaQueueManager(queues.input, queues.output)

    @lru_cache(maxsize=None)
    def get_storage(self):
        """Storage backend selected by configuration."""
        backend = self.config.storage.backend
        return storages.get_storage(backend)

    @lru_cache(maxsize=None)
    def get_response_strategy(self, storage=None):
        """Strategy aggregating analysis responses into storage."""
        if storage is None:
            storage = self.get_storage()
        return AggregationStorageStrategy(
            storage=storage,
            file_descriptor_manager=self.get_file_descriptor_manager(),
            upload_formatter=self.get_upload_formatter(),
        )

    @lru_cache(maxsize=None)
    def get_file_descriptor_manager(self):
        """Manager resolving file descriptors per operation pattern."""
        bucket = parse_disjunction_string(self.config.storage.bucket)
        patterns = self.get_operation2file_patterns()
        return FileDescriptorManager(
            bucket_name=bucket,
            operation2file_patterns=patterns,
        )

    @lru_cache(maxsize=None)
    def get_upload_formatter(self):
        """Upload formatter chosen by configuration ('identity' or 'projecting')."""
        formatters = {"identity": identity, "projecting": ProjectingUploadFormatter()}
        return formatters[self.config.service.upload_formatter]

    @lru_cache(maxsize=None)
    def get_response_formatter(self):
        """Response formatter chosen by configuration ('default' or 'identity')."""
        formatters = {
            "default": DefaultResponseFormatter(),
            "identity": IdentityResponseFormatter(),
        }
        return formatters[self.config.service.response_formatter]

    @lru_cache(maxsize=None)
    def get_operation2file_patterns(self):
        """Mapping of operation name to file patterns, straight from config."""
        return self.config.service.operations

    @lru_cache(maxsize=None)
    def get_downloader(self, storage=None):
        """Downloader pulling items from the configured storage bucket."""
        if storage is None:
            storage = self.get_storage()
        bucket = parse_disjunction_string(self.config.storage.bucket)
        return Downloader(
            storage=storage,
            bucket_name=bucket,
            file_descriptor_manager=self.get_file_descriptor_manager(),
        )
class Callback:
    """Callback applied to items pulled from the storage.

    Forwards each item to an analysis endpoint, lazily creating one client
    pipeline per endpoint (via the module-level ``get_pipeline``) and caching
    it for reuse.
    """

    def __init__(self, base_url):
        self.base_url = base_url
        self.endpoint2pipeline = {}

    def __make_endpoint(self, operation):
        """Build the full endpoint URL for the given operation."""
        return f"{self.base_url}/{operation}"

    def __get_pipeline(self, endpoint):
        """Return the cached pipeline for `endpoint`, creating it on first use."""
        if endpoint not in self.endpoint2pipeline:
            self.endpoint2pipeline[endpoint] = get_pipeline(endpoint)
        return self.endpoint2pipeline[endpoint]

    @staticmethod
    def __run_pipeline(pipeline, body):
        """Run `pipeline` on a single queue message `body`.

        Queue-level metadata (every key of `body` except "data"/"metadata") is
        merged into the storage item's metadata before the pipeline call and
        stripped from each result afterwards.

        Raises:
            AnalysisFailure: wrapping any exception raised by the pipeline.

        TODO: Since data and metadata are passed as singletons, there is no
        buffering and hence no batching happening within the pipeline. However,
        the queue acknowledgment logic needs to be changed in order to
        facilitate passing non-singletons, to only ack a message once a
        response is pulled from the output queue of the pipeline. Probably the
        pipeline return value needs to contain the queue message frame (or so),
        in order for the queue manager to tell which message to ack.
        TODO: casting list (lmap) on `analysis_response_stream` is a temporary
        solution, while the client pipeline operates on singletons
        ([data], [metadata]).
        """

        def combine_storage_item_metadata_with_queue_message_metadata(body):
            # Fold queue-level keys into the item's metadata for the pipeline.
            return merge(body["metadata"], omit(body, ["data", "metadata"]))

        def remove_queue_message_metadata(result):
            metadata = omit(result["metadata"], queue_message_keys(body))
            return {**result, "metadata": metadata}

        def queue_message_keys(body):
            return {*body.keys()}.difference({"data", "metadata"})

        try:
            data = body["data"]
            metadata = combine_storage_item_metadata_with_queue_message_metadata(body)
            analysis_response_stream = pipeline([data], [metadata])
            analysis_response_stream = lmap(remove_queue_message_metadata, analysis_response_stream)
            return analysis_response_stream
        except Exception as err:
            logger.error(err)
            raise AnalysisFailure from err

    def __call__(self, body: dict):
        """Process a single queue message.

        Returns the analysis response stream, or None when the analysis
        failed (best-effort: the failure is logged, not re-raised).
        """
        operation = body.get("operation", "submit")
        endpoint = self.__make_endpoint(operation)
        pipeline = self.__get_pipeline(endpoint)
        try:
            # Use the module logger (not the root logger via `logging.*`) so
            # messages respect this package's logging configuration.
            logger.debug(f"Requesting analysis from {endpoint}...")
            return self.__run_pipeline(pipeline, body)
        except AnalysisFailure:
            logger.warning(f"Exception caught when calling analysis endpoint {endpoint}.")
@lru_cache(maxsize=None)
def get_pipeline(endpoint):
    """Build (and cache, one per endpoint) the REST client pipeline."""
    packer = RestPacker()
    dispatcher = RestDispatcher(endpoint)
    receiver = RestReceiver()
    interpreter = rcompose(RestPickupStreamer(), RestReceiver())
    return ClientPipeline(packer, dispatcher, receiver, interpreter)

View File

@ -4,7 +4,7 @@ from multiprocessing import Process
from pyinfra.utils.retry import retry
from pyinfra.config import CONFIG
from pyinfra.default_objects import ComponentFactory
from pyinfra.component_factory import ComponentFactory
from pyinfra.exceptions import ConsumerError
from pyinfra.flask import run_probing_webserver, set_up_probing_webserver
from pyinfra.utils.banner import show_banner

View File

@ -48,5 +48,5 @@ webserver:
mock_analysis_endpoint: "http://127.0.0.1:5000"
use_docker_fixture: 0
use_docker_fixture: 1
logging: 0

View File

@ -4,7 +4,7 @@ import pytest
from funcy import pluck, lflatten
from pyinfra.config import CONFIG
from pyinfra.default_objects import ComponentFactory
from pyinfra.component_factory import ComponentFactory
from pyinfra.exceptions import ProcessingFailure
from pyinfra.visitor.strategies.response.forwarding import ForwardingStrategy
from test.utils.storage import pack_for_upload