From 892a803726946876f8b8cd7905a0e73c419b2fb1 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 28 Mar 2023 14:41:49 +0200 Subject: [PATCH] Refactoring Replace custom storage caching logic with LRU decorator --- pyinfra/payload_processing/processor.py | 8 ++--- pyinfra/storage/storage_info.py | 32 ----------------- tests/lru_test.py | 48 +++++++++++++++++++++++++ tests/storage_cache_test.py | 47 ------------------------ 4 files changed, 50 insertions(+), 85 deletions(-) create mode 100644 tests/lru_test.py delete mode 100644 tests/storage_cache_test.py diff --git a/pyinfra/payload_processing/processor.py b/pyinfra/payload_processing/processor.py index 2193d85..37f31a1 100644 --- a/pyinfra/payload_processing/processor.py +++ b/pyinfra/payload_processing/processor.py @@ -16,7 +16,7 @@ from pyinfra.storage.storage_info import ( get_storage_info_from_config, get_storage_info_from_endpoint, StorageInfo, - DefaultStorageCache, + get_storage_from_storage_info, ) logger = logging.getLogger() @@ -27,7 +27,6 @@ class PayloadProcessor: def __init__( self, default_storage_info: StorageInfo, - storage_cache: DefaultStorageCache, get_storage_info_from_tenant_id, payload_parser: QueueMessagePayloadParser, payload_formatter: QueueMessagePayloadFormatter, @@ -52,7 +51,6 @@ class PayloadProcessor: self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id self.default_storage_info = default_storage_info - self.storages = storage_cache def __call__(self, queue_message_payload: dict) -> dict: """Processes a queue message payload. @@ -77,7 +75,7 @@ class PayloadProcessor: logger.info(f"Processing {asdict(payload)} ...") storage_info = self._get_storage_info(payload.x_tenant_id) - storage = self.storages[storage_info] + storage = get_storage_from_storage_info(storage_info) bucket = storage_info.bucket_name download_file_to_process = make_downloader( @@ -112,7 +110,6 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P """Produces payload processor for queue manager.""" config = config or get_config() - storage_cache: DefaultStorageCache = DefaultStorageCache(max_size=10) default_storage_info: StorageInfo = get_storage_info_from_config(config) get_storage_info_from_tenant_id = partial( get_storage_info_from_endpoint, @@ -127,7 +124,6 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P return PayloadProcessor( default_storage_info, - storage_cache, get_storage_info_from_tenant_id, payload_parser, payload_formatter, diff --git a/pyinfra/storage/storage_info.py b/pyinfra/storage/storage_info.py index 65b1084..aefa15b 100644 --- a/pyinfra/storage/storage_info.py +++ b/pyinfra/storage/storage_info.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Callable import requests from azure.storage.blob import BlobServiceClient @@ -72,37 +71,6 @@ def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage: raise UnknownStorageBackend() -class DefaultStorageCache(dict): - def __init__(self, max_size=None, get_value: Callable = get_storage_from_storage_info, *args, **kwargs): - super().__init__(*args, **kwargs) - self.max_size = max_size - self.get_value = get_value - self.access_history = [] - - def __getitem__(self, key: StorageInfo): - self.access_history.append(key) - value = super().get(key) - if not value: - self.__setitem__(key, value) - value = self[key] - - return value - - def __setitem__(self, key: StorageInfo, value: Storage = None): - self._delete_oldest_key_if_max_size_breached() - value = value or self.get_value(key) - super().__setitem__(key, value) - self.access_history.append(key) - - def set_key(self, key: StorageInfo): - self.__setitem__(key) - - def _delete_oldest_key_if_max_size_breached(self): - if len(self) >= self.max_size: - oldest_key = self.access_history.pop(0) - del self[oldest_key] - - def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo: resp = requests.get(f"{endpoint}/{x_tenant_id}").json() diff --git a/tests/lru_test.py b/tests/lru_test.py new file mode 100644 index 0000000..9ab574f --- /dev/null +++ b/tests/lru_test.py @@ -0,0 +1,48 @@ +from functools import lru_cache + +import pytest + + +def func(callback): + return callback() + + +@pytest.fixture() +def fn(maxsize): + return lru_cache(maxsize)(func) + + +@pytest.fixture(params=[1, 2, 5]) +def maxsize(request): + return request.param + + +class Callback: + def __init__(self, x): + self.initial_x = x + self.x = x + + def __call__(self, *args, **kwargs): + self.x += 1 + return self.x + + def __hash__(self): + return hash(self.initial_x) + + +def test_adding_to_cache_within_maxsize_does_not_overwrite(fn, maxsize): + c = Callback(0) + for i in range(maxsize): + assert fn(c) == 1 + assert fn(c) == 1 + + +def test_adding_to_cache_more_than_maxsize_does_overwrite(fn, maxsize): + + callbacks = [Callback(i) for i in range(maxsize)] + + for i in range(maxsize): + assert fn(callbacks[i]) == i + 1 + + assert fn(Callback(maxsize)) == maxsize + 1 + assert fn(callbacks[0]) == 2 diff --git a/tests/storage_cache_test.py b/tests/storage_cache_test.py deleted file mode 100644 index 21152e6..0000000 --- a/tests/storage_cache_test.py +++ /dev/null @@ -1,47 +0,0 @@ -from dataclasses import asdict - -import pytest - -from pyinfra.storage.storage_info import AzureStorageInfo, StorageInfo, DefaultStorageCache - - -@pytest.fixture -def storage_infos(): - return [ - AzureStorageInfo(connection_string="first", bucket_name="Thorsten"), - AzureStorageInfo(connection_string="first", bucket_name="Klaus"), - AzureStorageInfo(connection_string="second", bucket_name="Thorsten"), - AzureStorageInfo(connection_string="third", bucket_name="Klaus"), - ] - - -@pytest.fixture -def storage_cache(max_size): - def get_connection_from_storage_info_mock(storage_info: StorageInfo): - return asdict(storage_info) - - return DefaultStorageCache(max_size=max_size, get_value=get_connection_from_storage_info_mock) - - -class TestStorageCache: - def test_same_storage_info_has_same_hash(self, storage_infos): - assert storage_infos[0].__hash__() == storage_infos[1].__hash__() - - @pytest.mark.parametrize("max_size", [4]) - def test_same_connection_different_bucket_does_not_create_new_connection(self, storage_cache, storage_infos): - value = storage_cache[storage_infos[0]] - assert value == asdict((storage_infos[0])) - - value = storage_cache[storage_infos[1]] - assert value == asdict((storage_infos[0])) - - @pytest.mark.parametrize("max_size", [2]) - def test_max_size_breached_removes_oldest_key(self, storage_cache, storage_infos): - storage_cache.set_key(storage_infos[0]) - storage_cache.set_key(storage_infos[2]) - storage_cache.set_key(storage_infos[3]) - - assert len(storage_cache) == 2 - assert storage_infos[0] not in storage_cache - assert storage_infos[2] in storage_cache - assert storage_infos[3] in storage_cache