completed multi download to single response logic. but broke pipeline test again, maybe?
This commit is contained in:
parent
fa3b08aef5
commit
a69f613fe6
@ -27,7 +27,7 @@ def get_consumer(callback=None):
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_visitor(callback):
|
||||
return QueueVisitor(get_storage(), callback, get_response_strategy())
|
||||
return QueueVisitor(storage=get_storage(), callback=callback, response_strategy=get_response_strategy())
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
|
||||
@ -109,7 +109,6 @@ class PikaQueueManager(QueueManager):
|
||||
n_attempts = get_n_previous_attempts(properties) + 1
|
||||
|
||||
try:
|
||||
|
||||
response_messages = visitor(json.loads(body))
|
||||
|
||||
if isinstance(response_messages, dict):
|
||||
|
||||
@ -4,10 +4,11 @@ import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from operator import itemgetter
|
||||
from typing import Callable, Dict, Union
|
||||
|
||||
from funcy import omit
|
||||
from funcy import omit, filter
|
||||
from more_itertools import peekable
|
||||
|
||||
from pyinfra.config import CONFIG, parse_disjunction_string
|
||||
@ -16,6 +17,7 @@ from pyinfra.parser.parser_composer import EitherParserComposer
|
||||
from pyinfra.parser.parsers.identity import IdentityBlobParser
|
||||
from pyinfra.parser.parsers.json import JsonBlobParser
|
||||
from pyinfra.parser.parsers.string import StringBlobParser
|
||||
from pyinfra.server.dispatcher.dispatcher import Nothing, is_not_nothing
|
||||
from pyinfra.server.packing import string_to_bytes
|
||||
from pyinfra.storage.storage import Storage
|
||||
|
||||
@ -30,8 +32,21 @@ def unique_hash(pages):
|
||||
|
||||
|
||||
def get_object_name(body):
|
||||
dossier_id, file_id, target_file_extension = itemgetter("dossierId", "fileId", "targetFileExtension")(body)
|
||||
object_name = f"{dossier_id}/{file_id}.{target_file_extension}"
|
||||
|
||||
body = deepcopy(body)
|
||||
|
||||
if "pages" not in body:
|
||||
body["pages"] = []
|
||||
|
||||
if "id" not in body:
|
||||
body["id"] = 0
|
||||
|
||||
dossier_id, file_id, pages, idnt, target_file_extension = itemgetter(
|
||||
"dossierId", "fileId", "pages", "id", "targetFileExtension"
|
||||
)(body)
|
||||
|
||||
object_name = f"{dossier_id}/{file_id}/{unique_hash(pages)}-id:{idnt}.{CONFIG.service.response_file_extension}"
|
||||
|
||||
return object_name
|
||||
|
||||
|
||||
@ -43,9 +58,7 @@ def get_response_object_name(body):
|
||||
if "id" not in body:
|
||||
body["id"] = 0
|
||||
|
||||
dossier_id, file_id, pages, idnt= itemgetter(
|
||||
"dossierId", "fileId", "pages", "id"
|
||||
)(body)
|
||||
dossier_id, file_id, pages, idnt = itemgetter("dossierId", "fileId", "pages", "id")(body)
|
||||
|
||||
object_name = f"{dossier_id}/{file_id}/{unique_hash(pages)}-id:{idnt}.{CONFIG.service.response_file_extension}"
|
||||
|
||||
@ -65,7 +78,7 @@ def get_response_object_descriptor(body):
|
||||
|
||||
class ResponseStrategy(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def handle_response(self, analysis_payload: dict):
|
||||
def handle_response(self, analysis_response: dict):
|
||||
pass
|
||||
|
||||
def __call__(self, analysis_payload: dict):
|
||||
@ -135,11 +148,10 @@ class AggregationStorageStrategy(ResponseStrategy):
|
||||
return self.put_object(data, storage_upload_info)
|
||||
|
||||
def upload_or_aggregate(self, analysis_payload, request_metadata, last=False):
|
||||
"""
|
||||
analysis_payload : {data: ..., metadata: ...}
|
||||
"""
|
||||
"""analysis_payload : {data: ..., metadata: ...}"""
|
||||
|
||||
storage_upload_info = build_storage_upload_info(analysis_payload, request_metadata)
|
||||
analysis_payload["metadata"].pop("id")
|
||||
|
||||
if analysis_payload["data"]:
|
||||
return self.put_object(json.dumps(analysis_payload).encode(), storage_upload_info)
|
||||
@ -148,15 +160,17 @@ class AggregationStorageStrategy(ResponseStrategy):
|
||||
self.buffer.append(analysis_payload)
|
||||
if last or self.dispatch_callback(storage_upload_info):
|
||||
return self.upload_queue_items(storage_upload_info)
|
||||
else:
|
||||
return Nothing
|
||||
|
||||
def handle_response(self, analysis_payload, final=False):
|
||||
def upload_or_aggregate(analysis_payload):
|
||||
return self.upload_or_aggregate(analysis_payload, request_metadata, last=not result_data.peek(False))
|
||||
|
||||
request_metadata = omit(analysis_payload, ["data"])
|
||||
result_data = peekable(analysis_payload["data"])
|
||||
for analysis_payload in result_data:
|
||||
print("------------------------------------------------------")
|
||||
print("analysis_payload", analysis_payload)
|
||||
print("request_metadata", request_metadata)
|
||||
yield self.upload_or_aggregate(analysis_payload, request_metadata, last=not result_data.peek(False))
|
||||
|
||||
yield from filter(is_not_nothing, map(upload_or_aggregate, result_data))
|
||||
|
||||
|
||||
def build_storage_upload_info(analysis_payload, request_metadata):
|
||||
@ -227,20 +241,25 @@ class QueueVisitor:
|
||||
|
||||
def load_data(self, queue_item_body):
|
||||
data = self.download_strategy(self.storage, queue_item_body)
|
||||
data = self.parsing_strategy(data)
|
||||
data = standardize(data)
|
||||
data = map(self.parsing_strategy, data)
|
||||
data = map(standardize, data)
|
||||
return data
|
||||
|
||||
def process_storage_item(self, data_metadata_pack):
|
||||
print(data_metadata_pack)
|
||||
return self.callback(data_metadata_pack)
|
||||
|
||||
def load_item_from_storage_and_process_with_callback(self, queue_item_body):
|
||||
"""Bundles the result from processing a storage item with the body of the corresponding queue item."""
|
||||
storage_item = self.load_data(queue_item_body)
|
||||
analysis_input = {**storage_item, **queue_item_body}
|
||||
result = self.process_storage_item(analysis_input)
|
||||
result_body = {"data": result, **queue_item_body}
|
||||
|
||||
storage_items = self.load_data(queue_item_body)
|
||||
|
||||
result_body = {"data": [], **queue_item_body}
|
||||
|
||||
for storage_item in storage_items:
|
||||
analysis_input = {**storage_item, **queue_item_body}
|
||||
result = self.process_storage_item(analysis_input)
|
||||
result_body["data"].extend(result)
|
||||
|
||||
return result_body
|
||||
|
||||
def __call__(self, queue_item_body):
|
||||
@ -279,12 +298,12 @@ def standardize(data) -> Dict:
|
||||
return wrap(string_to_bytes(data))
|
||||
|
||||
|
||||
def get_download_strategy():
|
||||
def get_download_strategy(download_strategy_type=None):
|
||||
download_strategies = {
|
||||
"single": SingleDownloadStrategy(),
|
||||
# "multi": MultiDownloadStratey(),
|
||||
"multi": MultiDownloadStrategy(),
|
||||
}
|
||||
return download_strategies.get(CONFIG.download_strategy, SingleDownloadStrategy())
|
||||
return download_strategies.get(download_strategy_type or CONFIG.download_strategy, SingleDownloadStrategy())
|
||||
|
||||
|
||||
class DownloadStrategy(abc.ABC):
|
||||
@ -295,7 +314,7 @@ class DownloadStrategy(abc.ABC):
|
||||
logging.debug(f"Downloaded {object_descriptor}.")
|
||||
assert isinstance(data, bytes)
|
||||
data = gzip.decompress(data)
|
||||
return data
|
||||
return [data]
|
||||
|
||||
@staticmethod
|
||||
def __download(storage, object_descriptor):
|
||||
@ -307,25 +326,30 @@ class DownloadStrategy(abc.ABC):
|
||||
|
||||
return data
|
||||
|
||||
# def __call__(self, storage, queue_item_body):
|
||||
# return self._load_data(storage, queue_item_body)
|
||||
|
||||
|
||||
class SingleDownloadStrategy(DownloadStrategy):
|
||||
def download(self, storage, object_descriptor):
|
||||
return self._load_data(storage, object_descriptor)
|
||||
def download(self, storage, queue_item_body):
|
||||
return self._load_data(storage, queue_item_body)
|
||||
|
||||
def __call__(self, storage, queue_item_body):
|
||||
return self.download(storage, queue_item_body)
|
||||
|
||||
|
||||
# class MultiDownloadStratey(DownloadStratey):
|
||||
#
|
||||
# def download(self, object_descriptor):
|
||||
# try:
|
||||
# data = self.storage.get_object(**object_descriptor)
|
||||
# except Exception as err:
|
||||
# logging.warning(f"Loading data from storage failed for {object_descriptor}.")
|
||||
# raise DataLoadingFailure from err
|
||||
#
|
||||
# return data
|
||||
class MultiDownloadStrategy(DownloadStrategy):
    """Download strategy that fetches every stored object belonging to the requested pages.

    Objects are selected by name: an object is downloaded when its name contains an
    "-id:<page>" marker for one of the pages listed in the queue item body.
    """

    def __init__(self):
        # TODO: pass in bucket name from outside / introduce closure-like abstraction for the bucket
        self.bucket_name = parse_disjunction_string(CONFIG.storage.bucket)

    def download(self, storage: Storage, queue_item_body):
        """Return an iterator over the gzip-decompressed objects matching the requested pages.

        Args:
            storage: Storage backend; `get_all_object_names` and `get_object` are called on it.
            queue_item_body: Queue message; its "pages" entry lists the page ids to fetch.

        Returns:
            An iterator yielding the decompressed content of each matching object.
        """
        pages = "|".join(map(str, queue_item_body["pages"]))
        # The negative lookahead keeps a requested page such as 1 from also selecting
        # ids 10, 11, ... With no pages requested the pattern is left as it was before.
        id_boundary = r"(?!\d)" if pages else ""
        matches_page = r".*-id:(" + pages + r")" + id_boundary + r".*"

        object_names = storage.get_all_object_names(self.bucket_name)
        # `filter` is the funcy variant imported by this module; it is handed the
        # regex string directly as the predicate.
        object_names = filter(matches_page, object_names)
        objects = (storage.get_object(self.bucket_name, objn) for objn in object_names)
        objects = map(gzip.decompress, objects)

        return objects

    def __call__(self, storage, queue_item_body):
        """Make the strategy callable; delegates to `download`."""
        return self.download(storage, queue_item_body)
|
||||
|
||||
@ -77,8 +77,8 @@ def mock_make_load_data():
|
||||
return load_data
|
||||
|
||||
|
||||
def pytest_make_parametrize_id(config, val, argname):
|
||||
return f"\n\t{argname}={val}\n"
|
||||
# def pytest_make_parametrize_id(config, val, argname):
|
||||
# return f"\n\t{argname}={val}\n"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
2
test/fixtures/input.py
vendored
2
test/fixtures/input.py
vendored
@ -94,7 +94,7 @@ def targets(data_message_pairs, input_data_items, operation, metadata):
|
||||
try:
|
||||
response_data, response_metadata = zip(*map(unpack, flatten(starmap(op, zip(input_data_items, metadata)))))
|
||||
|
||||
queue_message_keys = second(first(pair_data_with_queue_message([b""]))).keys()
|
||||
queue_message_keys = [*second(first(pair_data_with_queue_message([b""]))).keys(), "id"]
|
||||
response_metadata = lmap(partial(omit, keys=queue_message_keys), response_metadata)
|
||||
expected = lzip(response_data, response_metadata)
|
||||
|
||||
|
||||
16
test/fixtures/server.py
vendored
16
test/fixtures/server.py
vendored
@ -1,7 +1,9 @@
|
||||
import io
|
||||
import logging
|
||||
import socket
|
||||
from collections import Counter
|
||||
from multiprocessing import Process
|
||||
from operator import itemgetter
|
||||
from typing import Generator
|
||||
|
||||
import fitz
|
||||
@ -60,18 +62,26 @@ def operation_conditionally_batched(operation, batched):
|
||||
|
||||
@pytest.fixture
|
||||
def operation(core_operation):
|
||||
auto_counter = Counter()
|
||||
|
||||
def auto_count(metadata):
|
||||
idnt = itemgetter("dossierId", "fileId")(metadata)
|
||||
auto_counter[idnt] += 1
|
||||
return {**metadata, "id": auto_counter[idnt]} if "id" not in metadata else metadata
|
||||
|
||||
def op(data, metadata):
|
||||
assert isinstance(metadata, dict)
|
||||
result = core_operation(data, metadata)
|
||||
if isinstance(result, Generator):
|
||||
for data, metadata in result:
|
||||
yield data, omit(metadata, ["pages", "operation"])
|
||||
yield data, auto_count(omit(metadata, ["pages", "operation"]))
|
||||
else:
|
||||
data, metadata = result
|
||||
yield data, omit(metadata, ["pages", "operation"])
|
||||
yield data, auto_count(omit(metadata, ["pages", "operation"]))
|
||||
|
||||
if core_operation is Nothing:
|
||||
return Nothing
|
||||
|
||||
return op
|
||||
|
||||
|
||||
@ -94,7 +104,7 @@ def core_operation(item_type, one_to_many, analysis_task):
|
||||
return image_to_bytes(im.rotate(90)), metadata
|
||||
|
||||
def classify(_: bytes, metadata):
|
||||
return b"", {"classification": 1}
|
||||
return b"", {"classification": 1, **metadata}
|
||||
|
||||
def stream_pages(pdf: bytes, metadata):
|
||||
for i, page in enumerate(fitz.open(stream=pdf)):
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import gzip
|
||||
import json
|
||||
import re
|
||||
from itertools import starmap, repeat, chain
|
||||
from operator import itemgetter
|
||||
|
||||
import pytest
|
||||
from funcy import compose, lpluck
|
||||
from funcy import compose, lpluck, first, second
|
||||
|
||||
from pyinfra.default_objects import (
|
||||
get_callback,
|
||||
@ -15,7 +16,7 @@ from pyinfra.default_objects import (
|
||||
)
|
||||
from pyinfra.queue.consumer import Consumer
|
||||
from pyinfra.server.packing import unpack, pack
|
||||
from pyinfra.visitor import get_object_descriptor, QueueVisitor
|
||||
from pyinfra.visitor import get_object_descriptor, QueueVisitor, get_download_strategy
|
||||
from test.utils.input import pair_data_with_queue_message
|
||||
|
||||
|
||||
@ -57,7 +58,7 @@ from test.utils.input import pair_data_with_queue_message
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("n_pages", [2])
|
||||
@pytest.mark.parametrize("buffer_size", [2])
|
||||
@pytest.mark.parametrize("buffer_size", [1, 2])
|
||||
@pytest.mark.parametrize(
|
||||
"item_type",
|
||||
[
|
||||
@ -98,7 +99,6 @@ from test.utils.input import pair_data_with_queue_message
|
||||
],
|
||||
)
|
||||
def test_serving(server_process, bucket_name, components, targets, data_message_pairs, n_items, many_to_n):
|
||||
print()
|
||||
|
||||
storage, queue_manager, consumer = components
|
||||
|
||||
@ -106,33 +106,18 @@ def test_serving(server_process, bucket_name, components, targets, data_message_
|
||||
assert queue_manager.output_queue.to_list() == []
|
||||
assert [*storage.get_all_object_names(bucket_name)] == []
|
||||
|
||||
if n_items:
|
||||
assert data_message_pairs
|
||||
|
||||
if many_to_n:
|
||||
# for data, message in data_message_pairs:
|
||||
# # storage.put_object(**get_object_descriptor(message), data=gzip.compress(data))
|
||||
# print("message", message)
|
||||
# queue_manager.publish_request(message)
|
||||
|
||||
upload_data_to_storage_and_publish_requests_to_queue(storage, queue_manager, data_message_pairs)
|
||||
storage.clear_bucket(bucket_name)
|
||||
queue_manager.clear()
|
||||
outputs = get_data_uploaded_by_consumer(queue_manager, storage, bucket_name)
|
||||
# upload_data_to_folder_in_storage_and_publish_single_request_to_queue(storage, queue_manager, data_message_pairs)
|
||||
return
|
||||
upload_data_to_folder_in_storage_and_publish_single_request_to_queue(storage, queue_manager, data_message_pairs)
|
||||
else:
|
||||
print(22222222222222222222222222222222222222222222222222222222222222222222)
|
||||
if n_items:
|
||||
assert data_message_pairs
|
||||
upload_data_to_storage_and_publish_requests_to_queue(storage, queue_manager, data_message_pairs)
|
||||
|
||||
print(33333333333333333333333333333333333333333333333)
|
||||
consumer.consume_and_publish(n=int(many_to_n) or n_items)
|
||||
|
||||
consumer.consume_and_publish(n=n_items)
|
||||
|
||||
print(44444444444444444444444444444444444444444444444)
|
||||
print([*storage.get_all_object_names(bucket_name)])
|
||||
outputs = get_data_uploaded_by_consumer(queue_manager, storage, bucket_name)
|
||||
|
||||
print("BLYAT", data_message_pairs)
|
||||
# TODO: correctness of the targets should be validated as well, since their production has become non-trivial
|
||||
assert sorted(outputs) == sorted(targets)
|
||||
|
||||
@ -158,18 +143,30 @@ def upload_data_to_storage_and_publish_request_to_queue(storage, queue_manager,
|
||||
queue_manager.publish_request(message)
|
||||
|
||||
|
||||
# def upload_data_to_folder_in_storage_and_publish_single_request_to_queue(storage, queue_manager, data_message_pairs):
|
||||
# print()
|
||||
# print(22222222222222222222222222222222222222222)
|
||||
# assert data_message_pairs
|
||||
# for i, (data, message) in enumerate(data_message_pairs):
|
||||
# print(i)
|
||||
# object_descriptor = get_object_descriptor(message)
|
||||
# object_name = object_descriptor["object_name"]
|
||||
# object_descriptor["object_name"] = f"{object_name}/pages/{i}"
|
||||
# storage.put_object(**object_descriptor, data=gzip.compress(data))
|
||||
#
|
||||
# queue_manager.publish_request(message)
|
||||
# TODO: refactor
|
||||
def upload_data_to_folder_in_storage_and_publish_single_request_to_queue(storage, queue_manager, data_message_pairs):
    """Store each data item as its own page object, then publish one request for all of them.

    The message of the first pair serves as the reference message: its "pages" entry
    supplies the page id for each data item (paired in order), and it is the single
    request published to the queue once all uploads are done.
    """
    assert data_message_pairs

    ref_message = second(first(data_message_pairs))
    page_ids = ref_message["pages"]
    payloads = map(first, data_message_pairs)

    for payload, page_id in zip(payloads, page_ids):
        # A fresh descriptor is built per page because its object name is overwritten below
        descriptor = get_object_descriptor(ref_message)
        descriptor["object_name"] = build_filepath(descriptor, page_id)
        compressed = gzip.compress(payload)
        storage.put_object(**descriptor, data=compressed)

    queue_manager.publish_request(ref_message)
|
||||
|
||||
|
||||
def build_filepath(object_descriptor, page):
    """Derive the storage path of a single page object from an object descriptor.

    A "pages" segment is inserted directly before the last component of
    object_descriptor["object_name"], and every "id:<digits>" marker in the
    resulting path is rewritten to "id:<page>".

    Args:
        object_descriptor: Mapping whose "object_name" entry is a "/"-separated path.
        page: Page identifier written into the "id:" marker.

    Returns:
        The rewritten path as a string.
    """
    parts = object_descriptor["object_name"].split("/")
    parts.insert(-1, "pages")
    path = "/".join(parts)

    # Raw pattern avoids the invalid "\d" string escape; "\d+" replaces a multi-digit
    # id as a whole instead of only its first digit. The callable replacement keeps
    # the page text literal (no backslash/group expansion).
    return re.sub(r"id:\d+", lambda _match: f"id:{page}", path)
|
||||
|
||||
|
||||
def get_data_uploaded_by_consumer(queue_manager, storage, bucket_name):
|
||||
@ -196,8 +193,6 @@ def components(components_type, real_components, test_components, bucket_name):
|
||||
yield storage, queue_manager, consumer
|
||||
|
||||
queue_manager.clear()
|
||||
print()
|
||||
print("queue", queue_manager.input_queue.to_list())
|
||||
storage.clear_bucket(bucket_name)
|
||||
|
||||
|
||||
@ -215,11 +210,13 @@ def components_type(request):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_components(url):
|
||||
def real_components(url, many_to_n):
|
||||
callback = get_callback(url)
|
||||
consumer = get_consumer(callback)
|
||||
queue_manager = get_queue_manager()
|
||||
storage = get_storage()
|
||||
|
||||
consumer.visitor.download_strategy = get_download_strategy("multi" if many_to_n else "single")
|
||||
return storage, queue_manager, consumer
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user