diff --git a/pyinfra/utils/func.py b/pyinfra/utils/func.py
index cf5e308..163f802 100644
--- a/pyinfra/utils/func.py
+++ b/pyinfra/utils/func.py
@@ -1,6 +1,6 @@
 from itertools import starmap, tee
 
-from funcy import curry, compose
+from funcy import curry, compose, filter
 
 
 def lift(fn):
@@ -36,6 +36,10 @@ def foreach(fn, iterable):
         fn(itm)
 
 
+def flift(pred):
+    return curry(filter)(pred)
+
+
 def parallel_map(f1, f2):
     """Applies functions to a stream in parallel and yields a stream of tuples:
     parallel_map :: a -> b, a -> c -> [a] -> [(b, c)]
diff --git a/pyinfra/visitor/strategies/download/multi.py b/pyinfra/visitor/strategies/download/multi.py
index 82b00c9..f09932c 100644
--- a/pyinfra/visitor/strategies/download/multi.py
+++ b/pyinfra/visitor/strategies/download/multi.py
@@ -1,12 +1,15 @@
 import gzip
 from _operator import itemgetter
 from copy import deepcopy
+from functools import partial
+from typing import Collection
 
-from funcy import filter
+from funcy import compose
 
 from pyinfra.config import parse_disjunction_string, CONFIG
 from pyinfra.exceptions import InvalidMessage
 from pyinfra.storage.storage import Storage
+from pyinfra.utils.func import flift, lift
 from pyinfra.visitor.strategies.download.download import DownloadStrategy
 
 
@@ -15,14 +18,23 @@ class MultiDownloadStrategy(DownloadStrategy):
         # TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket
         self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
 
-    def download(self, storage: Storage, queue_item_body):
-        pages = "|".join(map(str, queue_item_body["pages"]))
-        matches_page = r".*id:(" + pages + r").*"
+    def get_page_matcher(self, pages):
+        pages = "|".join(map(str, pages))
+        page_matcher = r".*id:(" + pages + r").*"
+        return page_matcher
 
-        object_names = storage.get_all_object_names(self.bucket_name)
-        object_names = filter(matches_page, object_names)
-        objects = (storage.get_object(self.bucket_name, objn) for objn in object_names)
-        objects = map(gzip.decompress, objects)
+    def get_names_of_objects_by_pages(self, storage, pages: Collection[int]):
+        matches_page = flift(self.get_page_matcher(pages))
+        page_object_names = compose(matches_page, storage.get_all_object_names)(self.bucket_name)
+        return page_object_names
+
+    def download_and_decompress_object(self, storage, object_names):
+        download = lift(partial(storage.get_object, self.bucket_name))
+        return compose(lift(gzip.decompress), download)(object_names)
+
+    def download(self, storage: Storage, queue_item_body):
+        page_object_names = self.get_names_of_objects_by_pages(storage, queue_item_body["pages"])
+        objects = self.download_and_decompress_object(storage, page_object_names)
         return objects
 
 