refactoring: introduced input- and output-file specific methods to file descr mngr
This commit is contained in:
parent
ecced37150
commit
2c80d7cec0
@ -3,6 +3,7 @@ from functools import lru_cache
|
||||
|
||||
from funcy import rcompose, omit, merge, lmap, project
|
||||
|
||||
from pyinfra.config import parse_disjunction_string
|
||||
from pyinfra.exceptions import AnalysisFailure
|
||||
from pyinfra.queue.consumer import Consumer
|
||||
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
|
||||
@ -89,7 +90,8 @@ class ComponentFactory:
|
||||
"multi": self.get_multi_download_strategy(),
|
||||
}
|
||||
return download_strategies.get(
|
||||
download_strategy_type or self.config.service.download_strategy, self.get_single_download_strategy()
|
||||
download_strategy_type or self.config.service.download_strategy,
|
||||
self.get_single_download_strategy(),
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@ -102,7 +104,10 @@ class ComponentFactory:
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_file_descriptor_manager(self):
|
||||
return FileDescriptorManager(self.get_operation2file_patterns())
|
||||
return FileDescriptorManager(
|
||||
operation2file_patterns=self.get_operation2file_patterns(),
|
||||
bucket_name=parse_disjunction_string(self.config.storage.bucket),
|
||||
)
|
||||
|
||||
|
||||
class Callback:
|
||||
|
||||
@ -3,42 +3,29 @@ from _operator import itemgetter
|
||||
|
||||
from funcy import project
|
||||
|
||||
from pyinfra.config import parse_disjunction_string, CONFIG
|
||||
|
||||
|
||||
class FileDescriptorManager:
|
||||
def __init__(self, operation2file_patterns):
|
||||
# TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket
|
||||
self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
|
||||
def __init__(self, operation2file_patterns, bucket_name):
|
||||
self.operation2file_patterns = operation2file_patterns
|
||||
self.bucket_name = bucket_name
|
||||
|
||||
def get_object_name(self, queue_item_body: dict):
|
||||
def get_input_object_name(self, queue_item_body: dict):
|
||||
return self.get_object_name(queue_item_body, end="input")
|
||||
|
||||
file_descriptor = self.build_file_descriptor(queue_item_body)
|
||||
def get_output_object_name(self, queue_item_body: dict):
|
||||
return self.get_object_name(queue_item_body, end="output")
|
||||
|
||||
def get_object_name(self, queue_item_body: dict, end):
|
||||
|
||||
file_descriptor = self.build_file_descriptor(queue_item_body, end=end)
|
||||
file_descriptor["pages"] = [queue_item_body.get("id", 0)]
|
||||
|
||||
object_name = self.__build_matcher(file_descriptor)
|
||||
|
||||
return object_name
|
||||
|
||||
def build_file_descriptor(self, queue_item_body, end="input"):
|
||||
operation = queue_item_body.get("operation", "default")
|
||||
|
||||
file_pattern = self.operation2file_patterns[operation][end]
|
||||
|
||||
file_descriptor = {
|
||||
**project(queue_item_body, ["dossierId", "fileId", "pages"]),
|
||||
"pages": queue_item_body.get("pages", []),
|
||||
"extension": file_pattern["extension"],
|
||||
"subdir": file_pattern["subdir"],
|
||||
}
|
||||
return file_descriptor
|
||||
|
||||
def build_matcher(self, queue_item_body):
|
||||
file_descriptor = self.build_file_descriptor(queue_item_body)
|
||||
return self.__build_matcher(file_descriptor)
|
||||
|
||||
def __build_matcher(self, file_descriptor):
|
||||
@staticmethod
|
||||
def __build_matcher(file_descriptor):
|
||||
def make_filename(file_id, subdir, suffix):
|
||||
return os.path.join(file_id, subdir, suffix) if subdir else f"{file_id}.{suffix}"
|
||||
|
||||
@ -57,8 +44,37 @@ class FileDescriptorManager:
|
||||
|
||||
return matcher
|
||||
|
||||
def get_object_descriptor(self, queue_item_body):
|
||||
return {
|
||||
"bucket_name": parse_disjunction_string(CONFIG.storage.bucket),
|
||||
"object_name": self.get_object_name(queue_item_body),
|
||||
def build_file_descriptor(self, queue_item_body, end="input"):
|
||||
operation = queue_item_body.get("operation", "default")
|
||||
|
||||
file_pattern = self.operation2file_patterns[operation][end]
|
||||
|
||||
file_descriptor = {
|
||||
**project(queue_item_body, ["dossierId", "fileId", "pages"]),
|
||||
"pages": queue_item_body.get("pages", []),
|
||||
"extension": file_pattern["extension"],
|
||||
"subdir": file_pattern["subdir"],
|
||||
}
|
||||
return file_descriptor
|
||||
|
||||
def build_input_matcher(self, queue_item_body):
|
||||
return self.build_matcher(queue_item_body, end="input")
|
||||
|
||||
def build_output_matcher(self, queue_item_body):
|
||||
return self.build_matcher(queue_item_body, end="output")
|
||||
|
||||
def build_matcher(self, queue_item_body, end):
|
||||
file_descriptor = self.build_file_descriptor(queue_item_body, end=end)
|
||||
return self.__build_matcher(file_descriptor)
|
||||
|
||||
def get_input_object_descriptor(self, queue_item_body):
|
||||
return self.get_object_descriptor(queue_item_body, end="input")
|
||||
|
||||
def get_output_object_descriptor(self, queue_item_body):
|
||||
return self.get_object_descriptor(queue_item_body, end="output")
|
||||
|
||||
def get_object_descriptor(self, queue_item_body, end):
|
||||
return {
|
||||
"bucket_name": self.bucket_name,
|
||||
"object_name": self.get_object_name(queue_item_body, end=end),
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@ class MultiDownloadStrategy(DownloadStrategy):
|
||||
return map(compose(decompress, download), object_names)
|
||||
|
||||
def download(self, storage: Storage, queue_item_body):
|
||||
file_pattern = self.file_descriptor_manager.build_matcher(queue_item_body)
|
||||
file_pattern = self.file_descriptor_manager.build_input_matcher(queue_item_body)
|
||||
|
||||
page_object_names = self.get_names_of_objects_by_pages(storage, file_pattern)
|
||||
objects = self.download_and_decompress_object(storage, page_object_names)
|
||||
|
||||
@ -14,7 +14,7 @@ class SingleDownloadStrategy(DownloadStrategy):
|
||||
return self._load_data(storage, queue_item_body)
|
||||
|
||||
def _load_data(self, storage, queue_item_body):
|
||||
object_descriptor = self.file_descriptor_manager.get_object_descriptor(queue_item_body)
|
||||
object_descriptor = self.file_descriptor_manager.get_input_object_descriptor(queue_item_body)
|
||||
logging.debug(f"Downloading {object_descriptor}...")
|
||||
data = self.__download(storage, object_descriptor)
|
||||
logging.debug(f"Downloaded {object_descriptor}.")
|
||||
|
||||
@ -134,7 +134,7 @@ def upload_data_to_storage_and_publish_requests_to_queue(
|
||||
|
||||
# TODO: refactor; too many params
|
||||
def upload_data_to_storage_and_publish_request_to_queue(storage, queue_manager, data, message, file_descriptor_manager):
|
||||
storage.put_object(**file_descriptor_manager.get_object_descriptor(message), data=compress(data))
|
||||
storage.put_object(**file_descriptor_manager.get_input_object_descriptor(message), data=compress(data))
|
||||
queue_manager.publish_request(message)
|
||||
|
||||
|
||||
@ -148,7 +148,7 @@ def upload_data_to_folder_in_storage_and_publish_single_request_to_queue(
|
||||
pages = ref_message["pages"]
|
||||
|
||||
for data, page in zip(map(first, data_message_pairs), pages):
|
||||
object_descriptor = file_descriptor_manager.get_object_descriptor(ref_message)
|
||||
object_descriptor = file_descriptor_manager.get_input_object_descriptor(ref_message)
|
||||
object_descriptor["object_name"] = build_filepath(object_descriptor, page)
|
||||
|
||||
storage.put_object(**object_descriptor, data=compress(data))
|
||||
|
||||
@ -19,21 +19,21 @@ class TestVisitor:
|
||||
self, visitor, body, storage, bucket_name, file_descriptor_manager
|
||||
):
|
||||
storage.clear_bucket(bucket_name)
|
||||
storage.put_object(**file_descriptor_manager.get_object_descriptor(body), data=pack_for_upload(b"content"))
|
||||
storage.put_object(**file_descriptor_manager.get_input_object_descriptor(body), data=pack_for_upload(b"content"))
|
||||
data_received = list(visitor.load_data(body))
|
||||
assert [{"data": b"content", "metadata": {}}] == data_received
|
||||
|
||||
@pytest.mark.parametrize("response_strategy_name", ["forwarding", "storage"], scope="session")
|
||||
def test_visitor_pulls_and_processes_data(self, visitor, body, storage, bucket_name, file_descriptor_manager):
|
||||
storage.clear_bucket(bucket_name)
|
||||
storage.put_object(**file_descriptor_manager.get_object_descriptor(body), data=pack_for_upload(b"2"))
|
||||
storage.put_object(**file_descriptor_manager.get_input_object_descriptor(body), data=pack_for_upload(b"2"))
|
||||
response_body = visitor.load_items_from_storage_and_process_with_callback(body)
|
||||
assert response_body["analysis_payloads"] == ["22"]
|
||||
|
||||
@pytest.mark.parametrize("response_strategy_name", ["storage"], scope="session")
|
||||
def test_visitor_puts_response_on_storage(self, visitor, body, storage, bucket_name, file_descriptor_manager):
|
||||
storage.clear_bucket(bucket_name)
|
||||
storage.put_object(**file_descriptor_manager.get_object_descriptor(body), data=pack_for_upload(b"2"))
|
||||
storage.put_object(**file_descriptor_manager.get_input_object_descriptor(body), data=pack_for_upload(b"2"))
|
||||
response_body = visitor(body)
|
||||
assert "data" not in response_body
|
||||
assert json.loads(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user