diff --git a/pyinfra/default_objects.py b/pyinfra/default_objects.py index aa1e984..bb8ad2c 100644 --- a/pyinfra/default_objects.py +++ b/pyinfra/default_objects.py @@ -15,7 +15,8 @@ from pyinfra.storage import storages from pyinfra.visitor import QueueVisitor from pyinfra.visitor.response_formatter.formatters.default import DefaultResponseFormatter from pyinfra.visitor.response_formatter.formatters.identity import IdentityResponseFormatter -from pyinfra.visitor.strategies.download.multi import MultiDownloadStrategy, FileDescriptorManager +from pyinfra.visitor.strategies.download.multi import MultiDownloadStrategy +from pyinfra.file_descriptor_manager import FileDescriptorManager from pyinfra.visitor.strategies.download.single import SingleDownloadStrategy from pyinfra.visitor.strategies.response.aggregation import AggregationStorageStrategy @@ -41,6 +42,7 @@ class ComponentFactory: return QueueVisitor( storage=self.get_storage(), callback=callback, + download_strategy=self.get_download_strategy(), response_strategy=self.get_response_strategy(), response_formatter=self.get_response_formatter(), ) @@ -80,20 +82,28 @@ class ComponentFactory: def get_operation2file_patterns(self): return self.config.service.operations - @lru_cache(maxsize=None) - def get_file_descriptor_manager(self): - return FileDescriptorManager(self.get_operation2file_patterns()) - @lru_cache(maxsize=None) def get_download_strategy(self, download_strategy_type=None): download_strategies = { - "single": SingleDownloadStrategy(), - "multi": MultiDownloadStrategy(self.get_file_descriptor_manager()), + "single": self.get_single_download_strategy(), + "multi": self.get_multi_download_strategy(), } return download_strategies.get( - download_strategy_type or self.config.service.download_strategy, SingleDownloadStrategy() + download_strategy_type or self.config.service.download_strategy, self.get_single_download_strategy() ) + @lru_cache(maxsize=None) + def get_single_download_strategy(self): + return SingleDownloadStrategy(self.get_file_descriptor_manager()) + + @lru_cache(maxsize=None) + def get_multi_download_strategy(self): + return MultiDownloadStrategy(self.get_file_descriptor_manager()) + + @lru_cache(maxsize=None) + def get_file_descriptor_manager(self): + return FileDescriptorManager(self.get_operation2file_patterns()) + class Callback: def __init__(self, base_url): diff --git a/pyinfra/file_descriptor_manager.py b/pyinfra/file_descriptor_manager.py new file mode 100644 index 0000000..21581d1 --- /dev/null +++ b/pyinfra/file_descriptor_manager.py @@ -0,0 +1,64 @@ +import os +from _operator import itemgetter + +from funcy import project + +from pyinfra.config import parse_disjunction_string, CONFIG + + +class FileDescriptorManager: + def __init__(self, operation2file_patterns): + # TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket + self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket) + self.operation2file_patterns = operation2file_patterns + + def get_object_name(self, queue_item_body: dict): + + file_descriptor = self.build_file_descriptor(queue_item_body) + file_descriptor["pages"] = [queue_item_body.get("id", 0)] + + object_name = self.__build_matcher(file_descriptor) + + return object_name + + def build_file_descriptor(self, queue_item_body, end="input"): + operation = queue_item_body.get("operation", "default") + + file_pattern = self.operation2file_patterns[operation][end] + + file_descriptor = { + **project(queue_item_body, ["dossierId", "fileId", "pages"]), + "pages": queue_item_body.get("pages", []), + "extension": file_pattern["extension"], + "subdir": file_pattern["subdir"], + } + return file_descriptor + + def build_matcher(self, queue_item_body): + file_descriptor = self.build_file_descriptor(queue_item_body) + return self.__build_matcher(file_descriptor) + + def __build_matcher(self, file_descriptor): + def make_filename(file_id, subdir, suffix): + return os.path.join(file_id, subdir, suffix) if subdir else f"{file_id}.{suffix}" + + dossier_id, file_id, subdir, pages, extension = itemgetter( + "dossierId", "fileId", "subdir", "pages", "extension" + )(file_descriptor) + + if len(pages) > 1 and subdir: + page_re = "id:(" + "|".join(map(str, pages)) + ")." + elif len(pages) == 1 and subdir: + page_re = f"id:{pages[0]}." + else: + page_re = "" + + matcher = os.path.join(dossier_id, make_filename(file_id, subdir, page_re + extension)) + + return matcher + + def get_object_descriptor(self, queue_item_body): + return { + "bucket_name": parse_disjunction_string(CONFIG.storage.bucket), + "object_name": self.get_object_name(queue_item_body), + } diff --git a/pyinfra/visitor/strategies/download/download.py b/pyinfra/visitor/strategies/download/download.py index 260ab84..1f6c09b 100644 --- a/pyinfra/visitor/strategies/download/download.py +++ b/pyinfra/visitor/strategies/download/download.py @@ -1,38 +1,8 @@ import abc -import logging -from pyinfra.config import parse_disjunction_string, CONFIG -from pyinfra.exceptions import DataLoadingFailure -from pyinfra.utils.encoding import decompress +from pyinfra.file_descriptor_manager import FileDescriptorManager class DownloadStrategy(abc.ABC): - def _load_data(self, storage, queue_item_body): - object_descriptor = self.get_object_descriptor(queue_item_body) - logging.debug(f"Downloading {object_descriptor}...") - data = self.__download(storage, object_descriptor) - logging.debug(f"Downloaded {object_descriptor}.") - assert isinstance(data, bytes) - data = decompress(data) - return [data] - - @staticmethod - def __download(storage, object_descriptor): - try: - data = storage.get_object(**object_descriptor) - except Exception as err: - logging.warning(f"Loading data from storage failed for {object_descriptor}.") - raise DataLoadingFailure from err - - return data - - def get_object_descriptor(self, body): - return { - "bucket_name": parse_disjunction_string(CONFIG.storage.bucket), - "object_name": self.get_object_name(body), - } - - @staticmethod - @abc.abstractmethod - def get_object_name(body: dict): - raise NotImplementedError + def __init__(self, file_descriptor_manager: FileDescriptorManager): + self.file_descriptor_manager = file_descriptor_manager diff --git a/pyinfra/visitor/strategies/download/multi.py b/pyinfra/visitor/strategies/download/multi.py index 5066532..41b92b8 100644 --- a/pyinfra/visitor/strategies/download/multi.py +++ b/pyinfra/visitor/strategies/download/multi.py @@ -1,76 +1,20 @@ -import os -from _operator import itemgetter from functools import partial -from funcy import compose, project +from funcy import compose from pyinfra.config import parse_disjunction_string, CONFIG +from pyinfra.file_descriptor_manager import FileDescriptorManager from pyinfra.storage.storage import Storage from pyinfra.utils.encoding import decompress from pyinfra.utils.func import flift +from pyinfra.visitor.strategies.download.download import DownloadStrategy -class FileDescriptorManager: - def __init__(self, operation2file_patterns): - # TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket - self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket) - self.operation2file_patterns = operation2file_patterns - - def get_object_name(self, queue_item_body: dict): - - file_descriptor = self.build_file_descriptor(queue_item_body) - file_descriptor["pages"] = [queue_item_body.get("id", 0)] - - object_name = self.__build_matcher(file_descriptor) - - return object_name - - def build_file_descriptor(self, queue_item_body, end="input"): - operation = queue_item_body.get("operation", "default") - - file_pattern = self.operation2file_patterns[operation][end] - - file_descriptor = { - **project(queue_item_body, ["dossierId", "fileId", "pages"]), - "pages": queue_item_body.get("pages", []), - "extension": file_pattern["extension"], - "subdir": file_pattern["subdir"], - } - return file_descriptor - - def build_matcher(self, queue_item_body): - file_descriptor = self.build_file_descriptor(queue_item_body) - return self.__build_matcher(file_descriptor) - - def __build_matcher(self, file_descriptor): - - dossier_id, file_id, subdir, pages, extension = itemgetter( - "dossierId", "fileId", "subdir", "pages", "extension" - )(file_descriptor) - - if len(pages) > 1: - page_re = "id:(" + "|".join(map(str, pages)) + ")." - elif len(pages) == 1: - page_re = f"id:{pages[0]}." - else: - page_re = "" - - matcher = os.path.join(dossier_id, file_id, subdir, page_re + extension) - - return matcher - - def get_object_descriptor(self, queue_item_body): - return { - "bucket_name": parse_disjunction_string(CONFIG.storage.bucket), - "object_name": self.get_object_name(queue_item_body), - } - - -class MultiDownloadStrategy: +class MultiDownloadStrategy(DownloadStrategy): def __init__(self, file_descriptor_manager: FileDescriptorManager): # TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket) - self.file_descriptor_manager = file_descriptor_manager + super().__init__(file_descriptor_manager=file_descriptor_manager) def __call__(self, storage, queue_item_body): return self.download(storage, queue_item_body) diff --git a/pyinfra/visitor/strategies/download/single.py b/pyinfra/visitor/strategies/download/single.py index 42d8073..50e0a5c 100644 --- a/pyinfra/visitor/strategies/download/single.py +++ b/pyinfra/visitor/strategies/download/single.py @@ -1,25 +1,36 @@ -from _operator import itemgetter -from copy import deepcopy +import logging -from pyinfra.config import CONFIG +from pyinfra.exceptions import DataLoadingFailure +from pyinfra.file_descriptor_manager import FileDescriptorManager +from pyinfra.utils.encoding import decompress from pyinfra.visitor.strategies.download.download import DownloadStrategy class SingleDownloadStrategy(DownloadStrategy): + def __init__(self, file_descriptor_manager: FileDescriptorManager): + super().__init__(file_descriptor_manager=file_descriptor_manager) + def download(self, storage, queue_item_body): return self._load_data(storage, queue_item_body) + def _load_data(self, storage, queue_item_body): + object_descriptor = self.file_descriptor_manager.get_object_descriptor(queue_item_body) + logging.debug(f"Downloading {object_descriptor}...") + data = self.__download(storage, object_descriptor) + logging.debug(f"Downloaded {object_descriptor}.") + assert isinstance(data, bytes) + data = decompress(data) + return [data] + @staticmethod - def get_object_name(body: dict): + def __download(storage, object_descriptor): + try: + data = storage.get_object(**object_descriptor) + except Exception as err: + logging.warning(f"Loading data from storage failed for {object_descriptor}.") + raise DataLoadingFailure from err - # TODO: deepcopy still necessary? - body = deepcopy(body) - - dossier_id, file_id = itemgetter("dossierId", "fileId")(body) - - object_name = f"{dossier_id}/{file_id}.{CONFIG.service.target_file_extension}" - - return object_name + return data def __call__(self, storage, queue_item_body): return self.download(storage, queue_item_body) diff --git a/pyinfra/visitor/visitor.py b/pyinfra/visitor/visitor.py index b6bc851..32d1d23 100644 --- a/pyinfra/visitor/visitor.py +++ b/pyinfra/visitor/visitor.py @@ -9,7 +9,6 @@ from pyinfra.visitor.response_formatter.formatters.identity import IdentityRespo from pyinfra.visitor.strategies.blob_parsing.blob_parsing import BlobParsingStrategy from pyinfra.visitor.strategies.blob_parsing.dynamic import DynamicParsingStrategy from pyinfra.visitor.strategies.download.download import DownloadStrategy -from pyinfra.visitor.strategies.download.single import SingleDownloadStrategy from pyinfra.visitor.strategies.response.response import ResponseStrategy from pyinfra.visitor.strategies.response.storage import StorageStrategy from pyinfra.visitor.utils import standardize @@ -20,7 +19,7 @@ class QueueVisitor: self, storage: Storage, callback: Callable, - download_strategy: DownloadStrategy = None, + download_strategy: DownloadStrategy, parsing_strategy: BlobParsingStrategy = None, response_strategy: ResponseStrategy = None, response_formatter: ResponseFormatter = None, @@ -39,7 +38,7 @@ class QueueVisitor: """ self.storage = storage self.callback = callback - self.download_strategy = download_strategy or SingleDownloadStrategy() + self.download_strategy = download_strategy self.parsing_strategy = parsing_strategy or DynamicParsingStrategy() self.response_strategy = response_strategy or StorageStrategy(storage) self.response_formatter = response_formatter or IdentityResponseFormatter() diff --git a/test/integration_tests/serve_test.py b/test/integration_tests/serve_test.py index 4465a54..f64a27a 100644 --- a/test/integration_tests/serve_test.py +++ b/test/integration_tests/serve_test.py @@ -80,8 +80,8 @@ from test.config import CONFIG as TEST_CONFIG @pytest.mark.parametrize( "many_to_n", [ - # False, - True, + False, + # True, ], ) def test_serving(server_process, bucket_name, components, targets, data_message_pairs, n_items, many_to_n): @@ -134,6 +134,7 @@ def upload_data_to_storage_and_publish_requests_to_queue( # TODO: refactor; too many params def upload_data_to_storage_and_publish_request_to_queue(storage, queue_manager, data, message, file_descriptor_manager): + print(file_descriptor_manager.get_object_descriptor(message)) storage.put_object(**file_descriptor_manager.get_object_descriptor(message), data=compress(data)) queue_manager.publish_request(message) @@ -207,10 +208,11 @@ def components_type(request): @pytest.fixture -def real_components(url, many_to_n): +def real_components(url, download_strategy_type): CONFIG["service"]["operations"] = TEST_CONFIG.service.operations CONFIG["service"]["response_formatter"] = TEST_CONFIG.service.response_formatter + CONFIG["service"]["download_strategy"] = download_strategy_type component_factory = get_component_factory(CONFIG) @@ -220,12 +222,17 @@ def real_components(url, many_to_n): storage = component_factory.get_storage() file_descriptor_manager = component_factory.get_file_descriptor_manager() - download_strategy = component_factory.get_download_strategy("multi" if many_to_n else "single") + download_strategy = component_factory.get_download_strategy() - consumer.visitor.download_strategy = download_strategy + # consumer.visitor.download_strategy = download_strategy return storage, queue_manager, consumer, download_strategy, file_descriptor_manager +@pytest.fixture +def download_strategy_type(many_to_n): + return "multi" if many_to_n else "single" + + @pytest.fixture def test_components(url, queue_manager, storage, many_to_n):