Refactoring
Replace custom storage caching logic with LRU decorator
This commit is contained in:
parent
eafcd90260
commit
892a803726
@ -16,7 +16,7 @@ from pyinfra.storage.storage_info import (
|
||||
get_storage_info_from_config,
|
||||
get_storage_info_from_endpoint,
|
||||
StorageInfo,
|
||||
DefaultStorageCache,
|
||||
get_storage_from_storage_info,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
@ -27,7 +27,6 @@ class PayloadProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
default_storage_info: StorageInfo,
|
||||
storage_cache: DefaultStorageCache,
|
||||
get_storage_info_from_tenant_id,
|
||||
payload_parser: QueueMessagePayloadParser,
|
||||
payload_formatter: QueueMessagePayloadFormatter,
|
||||
@ -52,7 +51,6 @@ class PayloadProcessor:
|
||||
|
||||
self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
|
||||
self.default_storage_info = default_storage_info
|
||||
self.storages = storage_cache
|
||||
|
||||
def __call__(self, queue_message_payload: dict) -> dict:
|
||||
"""Processes a queue message payload.
|
||||
@ -77,7 +75,7 @@ class PayloadProcessor:
|
||||
logger.info(f"Processing {asdict(payload)} ...")
|
||||
|
||||
storage_info = self._get_storage_info(payload.x_tenant_id)
|
||||
storage = self.storages[storage_info]
|
||||
storage = get_storage_from_storage_info(storage_info)
|
||||
bucket = storage_info.bucket_name
|
||||
|
||||
download_file_to_process = make_downloader(
|
||||
@ -112,7 +110,6 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P
|
||||
"""Produces payload processor for queue manager."""
|
||||
config = config or get_config()
|
||||
|
||||
storage_cache: DefaultStorageCache = DefaultStorageCache(max_size=10)
|
||||
default_storage_info: StorageInfo = get_storage_info_from_config(config)
|
||||
get_storage_info_from_tenant_id = partial(
|
||||
get_storage_info_from_endpoint,
|
||||
@ -127,7 +124,6 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P
|
||||
|
||||
return PayloadProcessor(
|
||||
default_storage_info,
|
||||
storage_cache,
|
||||
get_storage_info_from_tenant_id,
|
||||
payload_parser,
|
||||
payload_formatter,
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import requests
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
@ -72,37 +71,6 @@ def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage:
|
||||
raise UnknownStorageBackend()
|
||||
|
||||
|
||||
class DefaultStorageCache(dict):
    """A dict-backed LRU cache mapping StorageInfo keys to Storage values.

    Missing keys are built on demand via ``get_value`` (defaults to
    ``get_storage_from_storage_info``).  When ``max_size`` is set, inserting
    a genuinely new key evicts the least recently used entry first;
    ``max_size=None`` means unbounded.
    """

    def __init__(self, max_size=None, get_value: Callable = get_storage_from_storage_info, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_size = max_size
        self.get_value = get_value
        # Recency order, oldest first; kept duplicate-free so index 0 is
        # always the true LRU key (the old version appended on every access
        # and grew without bound).
        self.access_history = []

    def __getitem__(self, key: StorageInfo):
        # Presence check instead of the old truthiness test: a cached falsy
        # value is no longer rebuilt (and re-evicted) on every access.
        if key not in self:
            self.__setitem__(key)
        else:
            self._record_access(key)
        return super().__getitem__(key)

    def __setitem__(self, key: StorageInfo, value: Storage = None):
        # Only a new key can grow the dict, so only then may eviction be
        # needed (the old code also evicted when re-assigning a present key).
        if key not in self:
            self._delete_oldest_key_if_max_size_breached()
        value = value if value is not None else self.get_value(key)
        super().__setitem__(key, value)
        self._record_access(key)

    def set_key(self, key: StorageInfo):
        """Insert ``key``, building its value with ``get_value``."""
        self.__setitem__(key)

    def _record_access(self, key: StorageInfo):
        # Move the key to the most-recent end of the history.
        if key in self.access_history:
            self.access_history.remove(key)
        self.access_history.append(key)

    def _delete_oldest_key_if_max_size_breached(self):
        # Guard max_size=None: the old `len(self) >= self.max_size` raised
        # TypeError when comparing an int against None.
        if self.max_size is None or len(self) < self.max_size:
            return
        # Skip stale history entries for keys already removed elsewhere
        # (the old `del self[oldest_key]` could raise KeyError).
        while self.access_history:
            oldest_key = self.access_history.pop(0)
            if oldest_key in self:
                del self[oldest_key]
                break
|
||||
|
||||
|
||||
def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo:
|
||||
resp = requests.get(f"{endpoint}/{x_tenant_id}").json()
|
||||
|
||||
|
||||
48
tests/lru_test.py
Normal file
48
tests/lru_test.py
Normal file
@ -0,0 +1,48 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def func(callback):
    """Invoke *callback* with no arguments and hand back its result."""
    result = callback()
    return result
|
||||
|
||||
|
||||
@pytest.fixture()
def fn(maxsize):
    """``func`` wrapped in an LRU cache bounded to ``maxsize`` entries."""
    cached = lru_cache(maxsize)(func)
    return cached
|
||||
|
||||
|
||||
@pytest.fixture(params=[1, 2, 5])
def maxsize(request):
    """Representative cache bounds: a single slot, a small one, a larger one."""
    size = request.param
    return size
|
||||
|
||||
|
||||
class Callback:
    """A callable counter whose hash is frozen at construction time.

    Each call bumps and returns the internal counter, so the return value
    reveals how many times the instance has actually been invoked — which
    is exactly what the LRU tests need to detect cache hits vs misses.
    """

    def __init__(self, x):
        # Keep the construction-time value separate from the live counter:
        # the hash must stay stable even as ``x`` changes.
        self.initial_x = x
        self.x = x

    def __call__(self, *args, **kwargs):
        self.x = self.x + 1
        return self.x

    def __hash__(self):
        # Hash by the immutable starting value, not the mutable counter.
        return hash(self.initial_x)
|
||||
|
||||
|
||||
def test_adding_to_cache_within_maxsize_does_not_overwrite(fn, maxsize):
    """Repeated calls with one hash-stable callback always hit the cache."""
    callback = Callback(0)
    # Every call after the first must come from the cache, so the counter
    # is only ever incremented once.
    for _ in range(maxsize):
        assert fn(callback) == 1
    assert fn(callback) == 1
|
||||
|
||||
|
||||
def test_adding_to_cache_more_than_maxsize_does_overwrite(fn, maxsize):
    """Exceeding maxsize evicts the least-recently-used entry."""
    callbacks = [Callback(index) for index in range(maxsize)]

    # Fill the cache exactly to capacity; each distinct callback is invoked
    # once, so its counter reads its position plus one.
    for index, callback in enumerate(callbacks):
        assert fn(callback) == index + 1

    # One more distinct callback overflows the cache and evicts the oldest
    # entry; re-calling the evicted callback therefore invokes it again,
    # bumping its counter from 1 to 2.
    assert fn(Callback(maxsize)) == maxsize + 1
    assert fn(callbacks[0]) == 2
|
||||
@ -1,47 +0,0 @@
|
||||
from dataclasses import asdict
|
||||
|
||||
import pytest
|
||||
|
||||
from pyinfra.storage.storage_info import AzureStorageInfo, StorageInfo, DefaultStorageCache
|
||||
|
||||
|
||||
@pytest.fixture
def storage_infos():
    """Four infos spanning three connection strings and two bucket names."""
    specs = [
        ("first", "Thorsten"),
        ("first", "Klaus"),
        ("second", "Thorsten"),
        ("third", "Klaus"),
    ]
    return [
        AzureStorageInfo(connection_string=connection, bucket_name=bucket)
        for connection, bucket in specs
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def storage_cache(max_size):
    """A DefaultStorageCache whose value factory just dict-ifies the key."""

    def fake_get_connection(storage_info: StorageInfo):
        # Stand-in for a real connection factory: the produced "connection"
        # is simply the storage info rendered as a dict.
        return asdict(storage_info)

    return DefaultStorageCache(max_size=max_size, get_value=fake_get_connection)
|
||||
|
||||
|
||||
class TestStorageCache:
    """Behavioral tests for DefaultStorageCache keyed by StorageInfo."""

    def test_same_storage_info_has_same_hash(self, storage_infos):
        # Same connection string, different bucket — expected to collide.
        first, second = storage_infos[0], storage_infos[1]
        assert hash(first) == hash(second)

    @pytest.mark.parametrize("max_size", [4])
    def test_same_connection_different_bucket_does_not_create_new_connection(self, storage_cache, storage_infos):
        # Both lookups must resolve to the connection built for the first
        # info — the cache reuses it instead of creating a second one.
        expected = asdict(storage_infos[0])
        assert storage_cache[storage_infos[0]] == expected
        assert storage_cache[storage_infos[1]] == expected

    @pytest.mark.parametrize("max_size", [2])
    def test_max_size_breached_removes_oldest_key(self, storage_cache, storage_infos):
        for info in (storage_infos[0], storage_infos[2], storage_infos[3]):
            storage_cache.set_key(info)

        # Third insert breached max_size=2, so the oldest key was evicted.
        assert len(storage_cache) == 2
        assert storage_infos[0] not in storage_cache
        assert storage_infos[2] in storage_cache
        assert storage_infos[3] in storage_cache
|
||||
Loading…
x
Reference in New Issue
Block a user