refactoring; single dl strat now uses file descriptor manager
This commit is contained in:
parent
23876501dc
commit
bb6b28bb4e
@ -15,7 +15,8 @@ from pyinfra.storage import storages
|
||||
from pyinfra.visitor import QueueVisitor
|
||||
from pyinfra.visitor.response_formatter.formatters.default import DefaultResponseFormatter
|
||||
from pyinfra.visitor.response_formatter.formatters.identity import IdentityResponseFormatter
|
||||
from pyinfra.visitor.strategies.download.multi import MultiDownloadStrategy, FileDescriptorManager
|
||||
from pyinfra.visitor.strategies.download.multi import MultiDownloadStrategy
|
||||
from pyinfra.file_descriptor_manager import FileDescriptorManager
|
||||
from pyinfra.visitor.strategies.download.single import SingleDownloadStrategy
|
||||
from pyinfra.visitor.strategies.response.aggregation import AggregationStorageStrategy
|
||||
|
||||
@ -41,6 +42,7 @@ class ComponentFactory:
|
||||
return QueueVisitor(
|
||||
storage=self.get_storage(),
|
||||
callback=callback,
|
||||
download_strategy=self.get_download_strategy(),
|
||||
response_strategy=self.get_response_strategy(),
|
||||
response_formatter=self.get_response_formatter(),
|
||||
)
|
||||
@ -80,20 +82,28 @@ class ComponentFactory:
|
||||
def get_operation2file_patterns(self):
|
||||
return self.config.service.operations
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_file_descriptor_manager(self):
|
||||
return FileDescriptorManager(self.get_operation2file_patterns())
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_download_strategy(self, download_strategy_type=None):
|
||||
download_strategies = {
|
||||
"single": SingleDownloadStrategy(),
|
||||
"multi": MultiDownloadStrategy(self.get_file_descriptor_manager()),
|
||||
"single": self.get_single_download_strategy(),
|
||||
"multi": self.get_multi_download_strategy(),
|
||||
}
|
||||
return download_strategies.get(
|
||||
download_strategy_type or self.config.service.download_strategy, SingleDownloadStrategy()
|
||||
download_strategy_type or self.config.service.download_strategy, self.get_single_download_strategy()
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_single_download_strategy(self):
|
||||
return SingleDownloadStrategy(self.get_file_descriptor_manager())
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_multi_download_strategy(self):
|
||||
return MultiDownloadStrategy(self.get_file_descriptor_manager())
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_file_descriptor_manager(self):
|
||||
return FileDescriptorManager(self.get_operation2file_patterns())
|
||||
|
||||
|
||||
class Callback:
|
||||
def __init__(self, base_url):
|
||||
|
||||
64
pyinfra/file_descriptor_manager.py
Normal file
64
pyinfra/file_descriptor_manager.py
Normal file
@ -0,0 +1,64 @@
|
||||
import os
|
||||
from _operator import itemgetter
|
||||
|
||||
from funcy import project
|
||||
|
||||
from pyinfra.config import parse_disjunction_string, CONFIG
|
||||
|
||||
|
||||
class FileDescriptorManager:
|
||||
def __init__(self, operation2file_patterns):
|
||||
# TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket
|
||||
self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
|
||||
self.operation2file_patterns = operation2file_patterns
|
||||
|
||||
def get_object_name(self, queue_item_body: dict):
|
||||
|
||||
file_descriptor = self.build_file_descriptor(queue_item_body)
|
||||
file_descriptor["pages"] = [queue_item_body.get("id", 0)]
|
||||
|
||||
object_name = self.__build_matcher(file_descriptor)
|
||||
|
||||
return object_name
|
||||
|
||||
def build_file_descriptor(self, queue_item_body, end="input"):
|
||||
operation = queue_item_body.get("operation", "default")
|
||||
|
||||
file_pattern = self.operation2file_patterns[operation][end]
|
||||
|
||||
file_descriptor = {
|
||||
**project(queue_item_body, ["dossierId", "fileId", "pages"]),
|
||||
"pages": queue_item_body.get("pages", []),
|
||||
"extension": file_pattern["extension"],
|
||||
"subdir": file_pattern["subdir"],
|
||||
}
|
||||
return file_descriptor
|
||||
|
||||
def build_matcher(self, queue_item_body):
|
||||
file_descriptor = self.build_file_descriptor(queue_item_body)
|
||||
return self.__build_matcher(file_descriptor)
|
||||
|
||||
def __build_matcher(self, file_descriptor):
|
||||
def make_filename(file_id, subdir, suffix):
|
||||
return os.path.join(file_id, subdir, suffix) if subdir else f"{file_id}.{suffix}"
|
||||
|
||||
dossier_id, file_id, subdir, pages, extension = itemgetter(
|
||||
"dossierId", "fileId", "subdir", "pages", "extension"
|
||||
)(file_descriptor)
|
||||
|
||||
if len(pages) > 1 and subdir:
|
||||
page_re = "id:(" + "|".join(map(str, pages)) + ")."
|
||||
elif len(pages) == 1 and subdir:
|
||||
page_re = f"id:{pages[0]}."
|
||||
else:
|
||||
page_re = ""
|
||||
|
||||
matcher = os.path.join(dossier_id, make_filename(file_id, subdir, page_re + extension))
|
||||
|
||||
return matcher
|
||||
|
||||
def get_object_descriptor(self, queue_item_body):
|
||||
return {
|
||||
"bucket_name": parse_disjunction_string(CONFIG.storage.bucket),
|
||||
"object_name": self.get_object_name(queue_item_body),
|
||||
}
|
||||
@ -1,38 +1,8 @@
|
||||
import abc
|
||||
import logging
|
||||
|
||||
from pyinfra.config import parse_disjunction_string, CONFIG
|
||||
from pyinfra.exceptions import DataLoadingFailure
|
||||
from pyinfra.utils.encoding import decompress
|
||||
from pyinfra.file_descriptor_manager import FileDescriptorManager
|
||||
|
||||
|
||||
class DownloadStrategy(abc.ABC):
|
||||
def _load_data(self, storage, queue_item_body):
|
||||
object_descriptor = self.get_object_descriptor(queue_item_body)
|
||||
logging.debug(f"Downloading {object_descriptor}...")
|
||||
data = self.__download(storage, object_descriptor)
|
||||
logging.debug(f"Downloaded {object_descriptor}.")
|
||||
assert isinstance(data, bytes)
|
||||
data = decompress(data)
|
||||
return [data]
|
||||
|
||||
@staticmethod
|
||||
def __download(storage, object_descriptor):
|
||||
try:
|
||||
data = storage.get_object(**object_descriptor)
|
||||
except Exception as err:
|
||||
logging.warning(f"Loading data from storage failed for {object_descriptor}.")
|
||||
raise DataLoadingFailure from err
|
||||
|
||||
return data
|
||||
|
||||
def get_object_descriptor(self, body):
|
||||
return {
|
||||
"bucket_name": parse_disjunction_string(CONFIG.storage.bucket),
|
||||
"object_name": self.get_object_name(body),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def get_object_name(body: dict):
|
||||
raise NotImplementedError
|
||||
def __init__(self, file_descriptor_manager: FileDescriptorManager):
|
||||
self.file_descriptor_manager = file_descriptor_manager
|
||||
|
||||
@ -1,76 +1,20 @@
|
||||
import os
|
||||
from _operator import itemgetter
|
||||
from functools import partial
|
||||
|
||||
from funcy import compose, project
|
||||
from funcy import compose
|
||||
|
||||
from pyinfra.config import parse_disjunction_string, CONFIG
|
||||
from pyinfra.file_descriptor_manager import FileDescriptorManager
|
||||
from pyinfra.storage.storage import Storage
|
||||
from pyinfra.utils.encoding import decompress
|
||||
from pyinfra.utils.func import flift
|
||||
from pyinfra.visitor.strategies.download.download import DownloadStrategy
|
||||
|
||||
|
||||
class FileDescriptorManager:
|
||||
def __init__(self, operation2file_patterns):
|
||||
# TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket
|
||||
self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
|
||||
self.operation2file_patterns = operation2file_patterns
|
||||
|
||||
def get_object_name(self, queue_item_body: dict):
|
||||
|
||||
file_descriptor = self.build_file_descriptor(queue_item_body)
|
||||
file_descriptor["pages"] = [queue_item_body.get("id", 0)]
|
||||
|
||||
object_name = self.__build_matcher(file_descriptor)
|
||||
|
||||
return object_name
|
||||
|
||||
def build_file_descriptor(self, queue_item_body, end="input"):
|
||||
operation = queue_item_body.get("operation", "default")
|
||||
|
||||
file_pattern = self.operation2file_patterns[operation][end]
|
||||
|
||||
file_descriptor = {
|
||||
**project(queue_item_body, ["dossierId", "fileId", "pages"]),
|
||||
"pages": queue_item_body.get("pages", []),
|
||||
"extension": file_pattern["extension"],
|
||||
"subdir": file_pattern["subdir"],
|
||||
}
|
||||
return file_descriptor
|
||||
|
||||
def build_matcher(self, queue_item_body):
|
||||
file_descriptor = self.build_file_descriptor(queue_item_body)
|
||||
return self.__build_matcher(file_descriptor)
|
||||
|
||||
def __build_matcher(self, file_descriptor):
|
||||
|
||||
dossier_id, file_id, subdir, pages, extension = itemgetter(
|
||||
"dossierId", "fileId", "subdir", "pages", "extension"
|
||||
)(file_descriptor)
|
||||
|
||||
if len(pages) > 1:
|
||||
page_re = "id:(" + "|".join(map(str, pages)) + ")."
|
||||
elif len(pages) == 1:
|
||||
page_re = f"id:{pages[0]}."
|
||||
else:
|
||||
page_re = ""
|
||||
|
||||
matcher = os.path.join(dossier_id, file_id, subdir, page_re + extension)
|
||||
|
||||
return matcher
|
||||
|
||||
def get_object_descriptor(self, queue_item_body):
|
||||
return {
|
||||
"bucket_name": parse_disjunction_string(CONFIG.storage.bucket),
|
||||
"object_name": self.get_object_name(queue_item_body),
|
||||
}
|
||||
|
||||
|
||||
class MultiDownloadStrategy:
|
||||
class MultiDownloadStrategy(DownloadStrategy):
|
||||
def __init__(self, file_descriptor_manager: FileDescriptorManager):
|
||||
# TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket
|
||||
self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
|
||||
self.file_descriptor_manager = file_descriptor_manager
|
||||
super().__init__(file_descriptor_manager=file_descriptor_manager)
|
||||
|
||||
def __call__(self, storage, queue_item_body):
|
||||
return self.download(storage, queue_item_body)
|
||||
|
||||
@ -1,25 +1,36 @@
|
||||
from _operator import itemgetter
|
||||
from copy import deepcopy
|
||||
import logging
|
||||
|
||||
from pyinfra.config import CONFIG
|
||||
from pyinfra.exceptions import DataLoadingFailure
|
||||
from pyinfra.file_descriptor_manager import FileDescriptorManager
|
||||
from pyinfra.utils.encoding import decompress
|
||||
from pyinfra.visitor.strategies.download.download import DownloadStrategy
|
||||
|
||||
|
||||
class SingleDownloadStrategy(DownloadStrategy):
|
||||
def __init__(self, file_descriptor_manager: FileDescriptorManager):
|
||||
super().__init__(file_descriptor_manager=file_descriptor_manager)
|
||||
|
||||
def download(self, storage, queue_item_body):
|
||||
return self._load_data(storage, queue_item_body)
|
||||
|
||||
def _load_data(self, storage, queue_item_body):
|
||||
object_descriptor = self.file_descriptor_manager.get_object_descriptor(queue_item_body)
|
||||
logging.debug(f"Downloading {object_descriptor}...")
|
||||
data = self.__download(storage, object_descriptor)
|
||||
logging.debug(f"Downloaded {object_descriptor}.")
|
||||
assert isinstance(data, bytes)
|
||||
data = decompress(data)
|
||||
return [data]
|
||||
|
||||
@staticmethod
|
||||
def get_object_name(body: dict):
|
||||
def __download(storage, object_descriptor):
|
||||
try:
|
||||
data = storage.get_object(**object_descriptor)
|
||||
except Exception as err:
|
||||
logging.warning(f"Loading data from storage failed for {object_descriptor}.")
|
||||
raise DataLoadingFailure from err
|
||||
|
||||
# TODO: deepcopy still necessary?
|
||||
body = deepcopy(body)
|
||||
|
||||
dossier_id, file_id = itemgetter("dossierId", "fileId")(body)
|
||||
|
||||
object_name = f"{dossier_id}/{file_id}.{CONFIG.service.target_file_extension}"
|
||||
|
||||
return object_name
|
||||
return data
|
||||
|
||||
def __call__(self, storage, queue_item_body):
|
||||
return self.download(storage, queue_item_body)
|
||||
|
||||
@ -9,7 +9,6 @@ from pyinfra.visitor.response_formatter.formatters.identity import IdentityRespo
|
||||
from pyinfra.visitor.strategies.blob_parsing.blob_parsing import BlobParsingStrategy
|
||||
from pyinfra.visitor.strategies.blob_parsing.dynamic import DynamicParsingStrategy
|
||||
from pyinfra.visitor.strategies.download.download import DownloadStrategy
|
||||
from pyinfra.visitor.strategies.download.single import SingleDownloadStrategy
|
||||
from pyinfra.visitor.strategies.response.response import ResponseStrategy
|
||||
from pyinfra.visitor.strategies.response.storage import StorageStrategy
|
||||
from pyinfra.visitor.utils import standardize
|
||||
@ -20,7 +19,7 @@ class QueueVisitor:
|
||||
self,
|
||||
storage: Storage,
|
||||
callback: Callable,
|
||||
download_strategy: DownloadStrategy = None,
|
||||
download_strategy: DownloadStrategy,
|
||||
parsing_strategy: BlobParsingStrategy = None,
|
||||
response_strategy: ResponseStrategy = None,
|
||||
response_formatter: ResponseFormatter = None,
|
||||
@ -39,7 +38,7 @@ class QueueVisitor:
|
||||
"""
|
||||
self.storage = storage
|
||||
self.callback = callback
|
||||
self.download_strategy = download_strategy or SingleDownloadStrategy()
|
||||
self.download_strategy = download_strategy
|
||||
self.parsing_strategy = parsing_strategy or DynamicParsingStrategy()
|
||||
self.response_strategy = response_strategy or StorageStrategy(storage)
|
||||
self.response_formatter = response_formatter or IdentityResponseFormatter()
|
||||
|
||||
@ -80,8 +80,8 @@ from test.config import CONFIG as TEST_CONFIG
|
||||
@pytest.mark.parametrize(
|
||||
"many_to_n",
|
||||
[
|
||||
# False,
|
||||
True,
|
||||
False,
|
||||
# True,
|
||||
],
|
||||
)
|
||||
def test_serving(server_process, bucket_name, components, targets, data_message_pairs, n_items, many_to_n):
|
||||
@ -134,6 +134,7 @@ def upload_data_to_storage_and_publish_requests_to_queue(
|
||||
|
||||
# TODO: refactor; too many params
|
||||
def upload_data_to_storage_and_publish_request_to_queue(storage, queue_manager, data, message, file_descriptor_manager):
|
||||
print(file_descriptor_manager.get_object_descriptor(message))
|
||||
storage.put_object(**file_descriptor_manager.get_object_descriptor(message), data=compress(data))
|
||||
queue_manager.publish_request(message)
|
||||
|
||||
@ -207,10 +208,11 @@ def components_type(request):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_components(url, many_to_n):
|
||||
def real_components(url, download_strategy_type):
|
||||
|
||||
CONFIG["service"]["operations"] = TEST_CONFIG.service.operations
|
||||
CONFIG["service"]["response_formatter"] = TEST_CONFIG.service.response_formatter
|
||||
CONFIG["service"]["download_strategy"] = download_strategy_type
|
||||
|
||||
component_factory = get_component_factory(CONFIG)
|
||||
|
||||
@ -220,12 +222,17 @@ def real_components(url, many_to_n):
|
||||
storage = component_factory.get_storage()
|
||||
file_descriptor_manager = component_factory.get_file_descriptor_manager()
|
||||
|
||||
download_strategy = component_factory.get_download_strategy("multi" if many_to_n else "single")
|
||||
download_strategy = component_factory.get_download_strategy()
|
||||
|
||||
consumer.visitor.download_strategy = download_strategy
|
||||
# consumer.visitor.download_strategy = download_strategy
|
||||
return storage, queue_manager, consumer, download_strategy, file_descriptor_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def download_strategy_type(many_to_n):
|
||||
return "multi" if many_to_n else "single"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_components(url, queue_manager, storage, many_to_n):
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user