Refactoring

Replace custom storage caching logic with LRU decorator
This commit is contained in:
Matthias Bisping 2023-03-28 14:41:49 +02:00
parent eafcd90260
commit 892a803726
4 changed files with 50 additions and 85 deletions

View File

@@ -16,7 +16,7 @@ from pyinfra.storage.storage_info import (
get_storage_info_from_config,
get_storage_info_from_endpoint,
StorageInfo,
DefaultStorageCache,
get_storage_from_storage_info,
)
logger = logging.getLogger()
@@ -27,7 +27,6 @@ class PayloadProcessor:
def __init__(
self,
default_storage_info: StorageInfo,
storage_cache: DefaultStorageCache,
get_storage_info_from_tenant_id,
payload_parser: QueueMessagePayloadParser,
payload_formatter: QueueMessagePayloadFormatter,
@@ -52,7 +51,6 @@ class PayloadProcessor:
self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
self.default_storage_info = default_storage_info
self.storages = storage_cache
def __call__(self, queue_message_payload: dict) -> dict:
"""Processes a queue message payload.
@@ -77,7 +75,7 @@
logger.info(f"Processing {asdict(payload)} ...")
storage_info = self._get_storage_info(payload.x_tenant_id)
storage = self.storages[storage_info]
storage = get_storage_from_storage_info(storage_info)
bucket = storage_info.bucket_name
download_file_to_process = make_downloader(
@@ -112,7 +110,6 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P
"""Produces payload processor for queue manager."""
config = config or get_config()
storage_cache: DefaultStorageCache = DefaultStorageCache(max_size=10)
default_storage_info: StorageInfo = get_storage_info_from_config(config)
get_storage_info_from_tenant_id = partial(
get_storage_info_from_endpoint,
@@ -127,7 +124,6 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P
return PayloadProcessor(
default_storage_info,
storage_cache,
get_storage_info_from_tenant_id,
payload_parser,
payload_formatter,

View File

@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Callable
import requests
from azure.storage.blob import BlobServiceClient
@@ -72,37 +71,6 @@ def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage:
raise UnknownStorageBackend()
class DefaultStorageCache(dict):
    """Bounded lazy cache mapping a StorageInfo key to a Storage value.

    Values are created on first access via ``get_value``. When ``max_size``
    is set, the oldest-accessed entry is evicted before a new one is
    inserted; ``max_size=None`` means the cache is unbounded.
    """

    def __init__(self, max_size=None, get_value: Callable = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # None means "no size limit" — eviction must tolerate it.
        self.max_size = max_size
        # Resolve the factory lazily so an explicit factory never touches
        # the module-level default.
        self.get_value = get_value or get_storage_from_storage_info
        # Keys in oldest-access-first order; eviction victims come from here.
        self.access_history = []

    def __getitem__(self, key: StorageInfo):
        # Membership test, NOT truthiness: a cached falsy value (0, {}, "")
        # must not be recomputed on every lookup.
        if key not in self:
            self.__setitem__(key)
        else:
            self.access_history.append(key)
        return super().__getitem__(key)

    def __setitem__(self, key: StorageInfo, value: Storage = None):
        self._delete_oldest_key_if_max_size_breached()
        # Only None means "compute it now"; an explicitly supplied falsy
        # value is stored as-is.
        value = self.get_value(key) if value is None else value
        super().__setitem__(key, value)
        self.access_history.append(key)

    def set_key(self, key: StorageInfo):
        """Insert *key*, computing its value with the factory."""
        self.__setitem__(key)

    def _delete_oldest_key_if_max_size_breached(self):
        # Unbounded cache: nothing to evict.
        if self.max_size is None:
            return
        while len(self) >= self.max_size:
            oldest_key = self.access_history.pop(0)
            # The history may hold duplicates or keys that were already
            # evicted; skip those instead of raising KeyError.
            if oldest_key in self:
                del self[oldest_key]
def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo:
resp = requests.get(f"{endpoint}/{x_tenant_id}").json()

48
tests/lru_test.py Normal file
View File

@@ -0,0 +1,48 @@
from functools import lru_cache
import pytest
def func(callback):
    """Invoke *callback* with no arguments and return its result."""
    result = callback()
    return result
@pytest.fixture()
def fn(maxsize):
    """`func` wrapped in an LRU cache bounded at `maxsize` entries."""
    wrapped = lru_cache(maxsize=maxsize)(func)
    return wrapped
@pytest.fixture(params=[1, 2, 5])
def maxsize(request):
    """Cache sizes to exercise: 1, 2 and 5."""
    size = request.param
    return size
class Callback:
    """Stateful callable: each invocation increments and returns a counter.

    The hash is derived from the construction-time seed, so instances built
    with the same seed collide in hash-based caches. No ``__eq__`` is
    defined on purpose: default identity equality keeps distinct instances
    distinct as cache keys.
    """

    def __init__(self, x):
        self.initial_x = x  # seed; fixed for the lifetime of the instance
        self.x = x          # mutable counter

    def __call__(self, *args, **kwargs):
        self.x = self.x + 1
        return self.x

    def __hash__(self):
        return hash(self.initial_x)
def test_adding_to_cache_within_maxsize_does_not_overwrite(fn, maxsize):
    """A callback cached within capacity is only ever invoked once."""
    callback = Callback(0)
    for _ in range(maxsize):
        assert fn(callback) == 1
    # One extra lookup after filling the cache: still the first result.
    assert fn(callback) == 1
def test_adding_to_cache_more_than_maxsize_does_overwrite(fn, maxsize):
    """Exceeding maxsize evicts the least recently used entry."""
    callbacks = [Callback(seed) for seed in range(maxsize)]
    for seed, callback in enumerate(callbacks):
        assert fn(callback) == seed + 1
    # A fresh key pushes the oldest entry out of the cache ...
    assert fn(Callback(maxsize)) == maxsize + 1
    # ... so the first callback is re-invoked: its counter is now 2.
    assert fn(callbacks[0]) == 2

View File

@ -1,47 +0,0 @@
from dataclasses import asdict
import pytest
from pyinfra.storage.storage_info import AzureStorageInfo, StorageInfo, DefaultStorageCache
@pytest.fixture
def storage_infos():
    """Four AzureStorageInfos: two share a connection string, two don't."""
    specs = [
        ("first", "Thorsten"),
        ("first", "Klaus"),
        ("second", "Thorsten"),
        ("third", "Klaus"),
    ]
    return [
        AzureStorageInfo(connection_string=conn, bucket_name=bucket)
        for conn, bucket in specs
    ]
@pytest.fixture
def storage_cache(max_size):
    """A DefaultStorageCache whose factory just dict-ifies the StorageInfo."""

    def fake_get_value(storage_info: StorageInfo):
        return asdict(storage_info)

    return DefaultStorageCache(max_size=max_size, get_value=fake_get_value)
class TestStorageCache:
    """Behavioral tests for DefaultStorageCache keyed by StorageInfo hashing."""

    def test_same_storage_info_has_same_hash(self, storage_infos):
        # Use the hash() builtin rather than calling __hash__() directly.
        # Infos sharing a connection string must collide as cache keys.
        assert hash(storage_infos[0]) == hash(storage_infos[1])

    @pytest.mark.parametrize("max_size", [4])
    def test_same_connection_different_bucket_does_not_create_new_connection(self, storage_cache, storage_infos):
        value = storage_cache[storage_infos[0]]
        assert value == asdict(storage_infos[0])
        # The second info is equivalent to the first as a cache key, so the
        # already-cached value for the first entry is returned unchanged.
        value = storage_cache[storage_infos[1]]
        assert value == asdict(storage_infos[0])

    @pytest.mark.parametrize("max_size", [2])
    def test_max_size_breached_removes_oldest_key(self, storage_cache, storage_infos):
        storage_cache.set_key(storage_infos[0])
        storage_cache.set_key(storage_infos[2])
        # Third insert breaches max_size=2 and must evict the oldest key.
        storage_cache.set_key(storage_infos[3])
        assert len(storage_cache) == 2
        assert storage_infos[0] not in storage_cache
        assert storage_infos[2] in storage_cache
        assert storage_infos[3] in storage_cache