refactoring: introduced input-file- and output-file-specific methods to the file descriptor manager

This commit is contained in:
Matthias Bisping 2022-06-23 10:59:57 +02:00
parent ecced37150
commit 2c80d7cec0
6 changed files with 59 additions and 38 deletions

View File

@@ -3,6 +3,7 @@ from functools import lru_cache
from funcy import rcompose, omit, merge, lmap, project
from pyinfra.config import parse_disjunction_string
from pyinfra.exceptions import AnalysisFailure
from pyinfra.queue.consumer import Consumer
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
@@ -89,7 +90,8 @@ class ComponentFactory:
"multi": self.get_multi_download_strategy(),
}
return download_strategies.get(
download_strategy_type or self.config.service.download_strategy, self.get_single_download_strategy()
download_strategy_type or self.config.service.download_strategy,
self.get_single_download_strategy(),
)
@lru_cache(maxsize=None)
@@ -102,7 +104,10 @@ class ComponentFactory:
@lru_cache(maxsize=None)
def get_file_descriptor_manager(self):
return FileDescriptorManager(self.get_operation2file_patterns())
return FileDescriptorManager(
operation2file_patterns=self.get_operation2file_patterns(),
bucket_name=parse_disjunction_string(self.config.storage.bucket),
)
class Callback:

View File

@@ -3,42 +3,29 @@ from _operator import itemgetter
from funcy import project
from pyinfra.config import parse_disjunction_string, CONFIG
class FileDescriptorManager:
def __init__(self, operation2file_patterns):
# TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket
self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
def __init__(self, operation2file_patterns, bucket_name):
self.operation2file_patterns = operation2file_patterns
self.bucket_name = bucket_name
def get_object_name(self, queue_item_body: dict):
def get_input_object_name(self, queue_item_body: dict):
return self.get_object_name(queue_item_body, end="input")
file_descriptor = self.build_file_descriptor(queue_item_body)
def get_output_object_name(self, queue_item_body: dict):
return self.get_object_name(queue_item_body, end="output")
def get_object_name(self, queue_item_body: dict, end):
file_descriptor = self.build_file_descriptor(queue_item_body, end=end)
file_descriptor["pages"] = [queue_item_body.get("id", 0)]
object_name = self.__build_matcher(file_descriptor)
return object_name
def build_file_descriptor(self, queue_item_body, end="input"):
operation = queue_item_body.get("operation", "default")
file_pattern = self.operation2file_patterns[operation][end]
file_descriptor = {
**project(queue_item_body, ["dossierId", "fileId", "pages"]),
"pages": queue_item_body.get("pages", []),
"extension": file_pattern["extension"],
"subdir": file_pattern["subdir"],
}
return file_descriptor
def build_matcher(self, queue_item_body):
file_descriptor = self.build_file_descriptor(queue_item_body)
return self.__build_matcher(file_descriptor)
def __build_matcher(self, file_descriptor):
@staticmethod
def __build_matcher(file_descriptor):
def make_filename(file_id, subdir, suffix):
return os.path.join(file_id, subdir, suffix) if subdir else f"{file_id}.{suffix}"
@@ -57,8 +44,37 @@ class FileDescriptorManager:
return matcher
def get_object_descriptor(self, queue_item_body):
return {
"bucket_name": parse_disjunction_string(CONFIG.storage.bucket),
"object_name": self.get_object_name(queue_item_body),
def build_file_descriptor(self, queue_item_body, end="input"):
operation = queue_item_body.get("operation", "default")
file_pattern = self.operation2file_patterns[operation][end]
file_descriptor = {
**project(queue_item_body, ["dossierId", "fileId", "pages"]),
"pages": queue_item_body.get("pages", []),
"extension": file_pattern["extension"],
"subdir": file_pattern["subdir"],
}
return file_descriptor
def build_input_matcher(self, queue_item_body):
return self.build_matcher(queue_item_body, end="input")
def build_output_matcher(self, queue_item_body):
return self.build_matcher(queue_item_body, end="output")
def build_matcher(self, queue_item_body, end):
file_descriptor = self.build_file_descriptor(queue_item_body, end=end)
return self.__build_matcher(file_descriptor)
def get_input_object_descriptor(self, queue_item_body):
return self.get_object_descriptor(queue_item_body, end="input")
def get_output_object_descriptor(self, queue_item_body):
return self.get_object_descriptor(queue_item_body, end="output")
def get_object_descriptor(self, queue_item_body, end):
return {
"bucket_name": self.bucket_name,
"object_name": self.get_object_name(queue_item_body, end=end),
}

View File

@@ -29,7 +29,7 @@ class MultiDownloadStrategy(DownloadStrategy):
return map(compose(decompress, download), object_names)
def download(self, storage: Storage, queue_item_body):
file_pattern = self.file_descriptor_manager.build_matcher(queue_item_body)
file_pattern = self.file_descriptor_manager.build_input_matcher(queue_item_body)
page_object_names = self.get_names_of_objects_by_pages(storage, file_pattern)
objects = self.download_and_decompress_object(storage, page_object_names)

View File

@@ -14,7 +14,7 @@ class SingleDownloadStrategy(DownloadStrategy):
return self._load_data(storage, queue_item_body)
def _load_data(self, storage, queue_item_body):
object_descriptor = self.file_descriptor_manager.get_object_descriptor(queue_item_body)
object_descriptor = self.file_descriptor_manager.get_input_object_descriptor(queue_item_body)
logging.debug(f"Downloading {object_descriptor}...")
data = self.__download(storage, object_descriptor)
logging.debug(f"Downloaded {object_descriptor}.")

View File

@@ -134,7 +134,7 @@ def upload_data_to_storage_and_publish_requests_to_queue(
# TODO: refactor; too many params
def upload_data_to_storage_and_publish_request_to_queue(storage, queue_manager, data, message, file_descriptor_manager):
storage.put_object(**file_descriptor_manager.get_object_descriptor(message), data=compress(data))
storage.put_object(**file_descriptor_manager.get_input_object_descriptor(message), data=compress(data))
queue_manager.publish_request(message)
@@ -148,7 +148,7 @@ def upload_data_to_folder_in_storage_and_publish_single_request_to_queue(
pages = ref_message["pages"]
for data, page in zip(map(first, data_message_pairs), pages):
object_descriptor = file_descriptor_manager.get_object_descriptor(ref_message)
object_descriptor = file_descriptor_manager.get_input_object_descriptor(ref_message)
object_descriptor["object_name"] = build_filepath(object_descriptor, page)
storage.put_object(**object_descriptor, data=compress(data))

View File

@ -19,21 +19,21 @@ class TestVisitor:
self, visitor, body, storage, bucket_name, file_descriptor_manager
):
storage.clear_bucket(bucket_name)
storage.put_object(**file_descriptor_manager.get_object_descriptor(body), data=pack_for_upload(b"content"))
storage.put_object(**file_descriptor_manager.get_input_object_descriptor(body), data=pack_for_upload(b"content"))
data_received = list(visitor.load_data(body))
assert [{"data": b"content", "metadata": {}}] == data_received
@pytest.mark.parametrize("response_strategy_name", ["forwarding", "storage"], scope="session")
def test_visitor_pulls_and_processes_data(self, visitor, body, storage, bucket_name, file_descriptor_manager):
storage.clear_bucket(bucket_name)
storage.put_object(**file_descriptor_manager.get_object_descriptor(body), data=pack_for_upload(b"2"))
storage.put_object(**file_descriptor_manager.get_input_object_descriptor(body), data=pack_for_upload(b"2"))
response_body = visitor.load_items_from_storage_and_process_with_callback(body)
assert response_body["analysis_payloads"] == ["22"]
@pytest.mark.parametrize("response_strategy_name", ["storage"], scope="session")
def test_visitor_puts_response_on_storage(self, visitor, body, storage, bucket_name, file_descriptor_manager):
storage.clear_bucket(bucket_name)
storage.put_object(**file_descriptor_manager.get_object_descriptor(body), data=pack_for_upload(b"2"))
storage.put_object(**file_descriptor_manager.get_input_object_descriptor(body), data=pack_for_upload(b"2"))
response_body = visitor(body)
assert "data" not in response_body
assert json.loads(