made object name construction logic part of download strategies
This commit is contained in:
parent
116c2b8924
commit
8537d4af50
@ -37,5 +37,10 @@ class IntentionalTestException(RuntimeError):
|
||||
class UnexpectedItemType(ValueError):
    """ValueError subclass for items of an unexpected type.

    NOTE(review): the raising sites are not visible in this excerpt.
    """
|
||||
|
||||
|
||||
class NoBufferCapacity(ValueError):
    """ValueError subclass signalling missing buffer capacity.

    NOTE(review): the raising sites are not visible in this excerpt.
    """
|
||||
|
||||
|
||||
class InvalidMessage(ValueError):
    """ValueError subclass for a queue message that lacks required content."""
|
||||
|
||||
@ -12,7 +12,7 @@ from funcy import omit, filter, lflatten
|
||||
from more_itertools import peekable
|
||||
|
||||
from pyinfra.config import CONFIG, parse_disjunction_string
|
||||
from pyinfra.exceptions import DataLoadingFailure
|
||||
from pyinfra.exceptions import DataLoadingFailure, InvalidMessage
|
||||
from pyinfra.parser.parser_composer import EitherParserComposer
|
||||
from pyinfra.parser.parsers.identity import IdentityBlobParser
|
||||
from pyinfra.parser.parsers.json import JsonBlobParser
|
||||
@ -24,56 +24,6 @@ from pyinfra.storage.storage import Storage
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def unique_hash(pages):
    """Return the hexadecimal MD5 digest of the items of ``pages`` joined by '-'.

    Each item is converted with ``str`` before joining; the joined text is
    encoded strictly as UTF-8 before hashing.
    """
    joined = "-".join(str(page) for page in pages)
    digest = hashlib.md5(joined.encode(encoding="UTF-8", errors="strict"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def get_object_name(body: dict):
    """Build the storage object name for the target file described by ``body``.

    The name is ``<dossierId>/<fileId>`` followed by an optional ``/pages/`` or
    ``/images/`` folder part, an optional ``id:<id>`` part and the configured
    target file extension.

    Raises:
        KeyError: if ``dossierId`` or ``fileId`` is missing from ``body``.
    """
    body = deepcopy(body)

    # "pages" takes precedence over "images" when both keys are present.
    if "pages" in body:
        folder = "/pages/"
    elif "images" in body:
        folder = "/images/"
    else:
        folder = ""

    # Without an explicit id, the id defaults to "0" only inside a folder.
    # A falsy id (e.g. integer 0) yields no id part at all.
    raw_id = body.get("id", "0" if folder else False)
    idnt = f"id:{raw_id}" if raw_id else ""

    dossier_id = body["dossierId"]
    file_id = body["fileId"]

    return f"{dossier_id}/{file_id}{folder}{idnt}.{CONFIG.service.target_file_extension}"
|
||||
|
||||
|
||||
def get_response_object_name(body):
    """Build the storage object name for the response file described by ``body``.

    Side effect: ``body`` is modified in place; missing ``pages`` and ``id``
    entries are filled with ``[]`` and ``0`` respectively.

    Raises:
        KeyError: if ``dossierId`` or ``fileId`` is missing from ``body``.
    """
    body.setdefault("pages", [])
    body.setdefault("id", 0)

    dossier_id = body["dossierId"]
    file_id = body["fileId"]
    pages = body["pages"]
    idnt = body["id"]

    extension = CONFIG.service.response_file_extension
    return f"{dossier_id}/{file_id}/{unique_hash(pages)}-id:{idnt}.{extension}"
|
||||
|
||||
|
||||
def get_object_descriptor(body):
    """Return the ``bucket_name`` / ``object_name`` mapping for the target file of ``body``."""
    bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
    object_name = get_object_name(body)
    return {"bucket_name": bucket_name, "object_name": object_name}
|
||||
|
||||
|
||||
def get_response_object_descriptor(body):
    """Return the ``bucket_name`` / ``object_name`` mapping for the response file of ``body``."""
    bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
    object_name = get_response_object_name(body)
    return {"bucket_name": bucket_name, "object_name": object_name}
|
||||
|
||||
|
||||
class ResponseStrategy(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def handle_response(self, analysis_response: dict):
|
||||
@ -82,13 +32,34 @@ class ResponseStrategy(abc.ABC):
|
||||
def __call__(self, analysis_response: dict):
    """Make the strategy callable by delegating to ``handle_response``."""
    handled = self.handle_response(analysis_response)
    return handled
|
||||
|
||||
def get_response_object_descriptor(self, body):
    """Return the ``bucket_name`` / ``object_name`` mapping for the response file of ``body``.

    The object name comes from ``self.get_response_object_name``.
    """
    bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
    object_name = self.get_response_object_name(body)
    return {"bucket_name": bucket_name, "object_name": object_name}
|
||||
|
||||
@staticmethod
def get_response_object_name(body):
    """Build the storage object name for the response file described by ``body``.

    The name has the form ``<dossierId>/<fileId>/id:<id>.<response extension>``.

    Side effect: ``body`` is modified in place; missing ``pages`` and ``id``
    entries are filled with ``[]`` and ``0``. The ``pages`` default is not
    part of the name but is kept because callers serialize ``body`` afterwards.

    Raises:
        KeyError: if ``dossierId`` or ``fileId`` is missing from ``body``.
    """
    body.setdefault("pages", [])
    body.setdefault("id", 0)

    dossier_id, file_id, idnt = itemgetter("dossierId", "fileId", "id")(body)

    return f"{dossier_id}/{file_id}/id:{idnt}.{CONFIG.service.response_file_extension}"
|
||||
|
||||
|
||||
class StorageStrategy(ResponseStrategy):
|
||||
def __init__(self, storage):
|
||||
self.storage = storage
|
||||
|
||||
def handle_response(self, body):
|
||||
response_object_descriptor = get_response_object_descriptor(body)
|
||||
response_object_descriptor = self.get_response_object_descriptor(body)
|
||||
self.storage.put_object(**response_object_descriptor, data=gzip.compress(json.dumps(body).encode()))
|
||||
body.pop("data")
|
||||
body["responseFile"] = response_object_descriptor["object_name"]
|
||||
@ -132,7 +103,7 @@ class AggregationStorageStrategy(ResponseStrategy):
|
||||
self.buffer = deque()
|
||||
|
||||
def put_object(self, data: bytes, storage_upload_info):
    """Upload gzip-compressed ``data`` and return the upload info extended by the object name.

    Args:
        data: raw bytes to compress and store.
        storage_upload_info: mapping passed to ``self.get_response_object_descriptor``
            to derive the bucket and object name.

    Returns:
        A new dict with the entries of ``storage_upload_info`` plus
        ``responseFile`` set to the object name that was written.
    """
    # The descriptor comes from the strategy's own method; the module-level
    # helper of the same name is superseded by it.
    object_descriptor = self.get_response_object_descriptor(storage_upload_info)
    self.storage.put_object(**object_descriptor, data=gzip.compress(data))
    return {**storage_upload_info, "responseFile": object_descriptor["object_name"]}
|
||||
|
||||
@ -303,7 +274,7 @@ def get_download_strategy(download_strategy_type=None):
|
||||
|
||||
class DownloadStrategy(abc.ABC):
|
||||
def _load_data(self, storage, queue_item_body):
|
||||
object_descriptor = get_object_descriptor(queue_item_body)
|
||||
object_descriptor = self.get_object_descriptor(queue_item_body)
|
||||
logging.debug(f"Downloading {object_descriptor}...")
|
||||
data = self.__download(storage, object_descriptor)
|
||||
logging.debug(f"Downloaded {object_descriptor}.")
|
||||
@ -321,11 +292,31 @@ class DownloadStrategy(abc.ABC):
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
@abc.abstractmethod
def get_object_name(body: dict):
    """Return the storage object name for ``body``; subclasses must override this."""
    raise NotImplementedError
|
||||
|
||||
def get_object_descriptor(self, body):
    """Return the ``bucket_name`` / ``object_name`` mapping for the target file of ``body``.

    The object name comes from the strategy's own ``get_object_name``.
    """
    bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
    object_name = self.get_object_name(body)
    return {"bucket_name": bucket_name, "object_name": object_name}
|
||||
|
||||
|
||||
class SingleDownloadStrategy(DownloadStrategy):
    """Download strategy whose object name is built from dossier id and file id only."""

    def download(self, storage, queue_item_body):
        """Load the data for ``queue_item_body`` from ``storage`` via ``_load_data``."""
        return self._load_data(storage, queue_item_body)

    @staticmethod
    def get_object_name(body: dict):
        """Return ``<dossierId>/<fileId>.<target extension>`` for ``body``.

        Raises:
            KeyError: if ``dossierId`` or ``fileId`` is missing from ``body``.
        """
        # ``body`` is only read here, so no defensive copy is required.
        dossier_id, file_id = itemgetter("dossierId", "fileId")(body)

        return f"{dossier_id}/{file_id}.{CONFIG.service.target_file_extension}"

    def __call__(self, storage, queue_item_body):
        """Make the strategy callable by delegating to ``download``."""
        return self.download(storage, queue_item_body)
|
||||
|
||||
@ -346,5 +337,26 @@ class MultiDownloadStrategy(DownloadStrategy):
|
||||
|
||||
return objects
|
||||
|
||||
@staticmethod
def get_object_name(body: dict):
    """Return ``<dossierId>/<fileId>/<folder>/id:<id>.<target extension>`` for ``body``.

    ``<folder>`` is ``pages`` when that key is present in ``body``, otherwise
    ``images``. ``<id>`` is ``body["id"]`` and defaults to ``0``.

    Raises:
        InvalidMessage: if ``body`` contains neither ``pages`` nor ``images``.
        KeyError: if ``dossierId`` or ``fileId`` is missing from ``body``.
    """
    # Pick the folder key BEFORE formatting it into a path segment: a
    # formatted string such as "/False/" is always truthy, so validating the
    # formatted value could never detect a missing folder.
    folder_key = next((key for key in ("pages", "images") if key in body), None)
    if folder_key is None:
        # NOTE(review): "oder" in the message below looks like a typo for "or";
        # the text is kept unchanged in case it is matched elsewhere.
        raise InvalidMessage("Expected a folder like 'images' oder 'pages' to be specified in message.")

    idnt = f"id:{body.get('id', 0)}"

    # ``body`` is only read here, so no defensive copy is required.
    dossier_id, file_id = itemgetter("dossierId", "fileId")(body)

    return f"{dossier_id}/{file_id}/{folder_key}/{idnt}.{CONFIG.service.target_file_extension}"
|
||||
|
||||
def __call__(self, storage, queue_item_body):
    """Make the strategy callable by delegating to ``download``."""
    downloaded = self.download(storage, queue_item_body)
    return downloaded
|
||||
|
||||
@ -16,7 +16,7 @@ from pyinfra.default_objects import (
|
||||
)
|
||||
from pyinfra.queue.consumer import Consumer
|
||||
from pyinfra.server.packing import unpack, pack
|
||||
from pyinfra.visitor import get_object_descriptor, QueueVisitor, get_download_strategy
|
||||
from pyinfra.visitor import QueueVisitor, get_download_strategy
|
||||
from test.utils.input import pair_data_with_queue_message
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ from test.utils.input import pair_data_with_queue_message
|
||||
@pytest.mark.parametrize(
|
||||
"analysis_task",
|
||||
[
|
||||
False,
|
||||
# False,
|
||||
True,
|
||||
],
|
||||
)
|
||||
@ -89,7 +89,9 @@ from test.utils.input import pair_data_with_queue_message
|
||||
True,
|
||||
],
|
||||
)
|
||||
def test_serving(server_process, bucket_name, components, targets, data_message_pairs, n_items, many_to_n):
|
||||
def test_serving(
|
||||
server_process, bucket_name, components, targets, data_message_pairs, n_items, many_to_n, download_strategy
|
||||
):
|
||||
|
||||
storage, queue_manager, consumer = components
|
||||
|
||||
@ -101,9 +103,13 @@ def test_serving(server_process, bucket_name, components, targets, data_message_
|
||||
assert data_message_pairs
|
||||
|
||||
if many_to_n:
|
||||
upload_data_to_folder_in_storage_and_publish_single_request_to_queue(storage, queue_manager, data_message_pairs)
|
||||
upload_data_to_folder_in_storage_and_publish_single_request_to_queue(
|
||||
storage, queue_manager, data_message_pairs, download_strategy
|
||||
)
|
||||
else:
|
||||
upload_data_to_storage_and_publish_requests_to_queue(storage, queue_manager, data_message_pairs)
|
||||
upload_data_to_storage_and_publish_requests_to_queue(
|
||||
storage, queue_manager, data_message_pairs, download_strategy
|
||||
)
|
||||
|
||||
consumer.consume_and_publish(n=int(many_to_n) or n_items)
|
||||
|
||||
@ -124,25 +130,29 @@ def data_message_pairs(data_metadata_packs):
|
||||
return data_message_pairs
|
||||
|
||||
|
||||
# TODO: refactor; too many params
def upload_data_to_storage_and_publish_requests_to_queue(storage, queue_manager, data_message_pairs, download_strategy):
    """Upload every data item and publish its request message, one pair at a time.

    ``download_strategy`` is forwarded to the per-item helper, which uses it to
    derive the object descriptor for each message.
    """
    for data, message in data_message_pairs:
        upload_data_to_storage_and_publish_request_to_queue(storage, queue_manager, data, message, download_strategy)
|
||||
|
||||
|
||||
# TODO: refactor; too many params
def upload_data_to_storage_and_publish_request_to_queue(storage, queue_manager, data, message, download_strategy):
    """Store gzip-compressed ``data`` for ``message`` and publish ``message`` as a request.

    The object descriptor is derived from ``message`` by
    ``download_strategy.get_object_descriptor``, so the upload lands where the
    strategy will later look for it.
    """
    object_descriptor = download_strategy.get_object_descriptor(message)
    storage.put_object(**object_descriptor, data=gzip.compress(data))
    queue_manager.publish_request(message)
|
||||
|
||||
|
||||
# TODO: refactor
|
||||
def upload_data_to_folder_in_storage_and_publish_single_request_to_queue(storage, queue_manager, data_message_pairs):
|
||||
# TODO: refactor body; too long and scripty
|
||||
def upload_data_to_folder_in_storage_and_publish_single_request_to_queue(
|
||||
storage, queue_manager, data_message_pairs, download_strategy
|
||||
):
|
||||
assert data_message_pairs
|
||||
|
||||
ref_message = second(first(data_message_pairs))
|
||||
pages = ref_message["pages"]
|
||||
|
||||
for data, page in zip(map(first, data_message_pairs), pages):
|
||||
object_descriptor = get_object_descriptor(ref_message)
|
||||
object_descriptor = download_strategy.get_object_descriptor(ref_message)
|
||||
object_descriptor["object_name"] = build_filepath(object_descriptor, page)
|
||||
|
||||
storage.put_object(**object_descriptor, data=gzip.compress(data))
|
||||
@ -202,16 +212,21 @@ def components_type(request):
|
||||
|
||||
|
||||
@pytest.fixture
def real_components(url, download_strategy):
    """Provide ``(storage, queue_manager, consumer)`` with the given download strategy installed.

    The strategy instance comes from the ``download_strategy`` fixture, so the
    consumer's visitor and the test's upload helpers share the same object.
    """
    callback = get_callback(url)
    consumer = get_consumer(callback)
    queue_manager = get_queue_manager()
    storage = get_storage()

    consumer.visitor.download_strategy = download_strategy
    return storage, queue_manager, consumer
|
||||
|
||||
|
||||
@pytest.fixture
def download_strategy(many_to_n):
    """Provide the "multi" download strategy when ``many_to_n`` is truthy, else "single"."""
    strategy_type = "multi" if many_to_n else "single"
    return get_download_strategy(strategy_type)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_components(url, queue_manager, storage):
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user