diff --git a/music_kraken/connection/cache.py b/music_kraken/connection/cache.py index 232430f..cf4b797 100644 --- a/music_kraken/connection/cache.py +++ b/music_kraken/connection/cache.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import List, Optional from functools import lru_cache @@ -17,6 +17,8 @@ class CacheAttribute: created: datetime expires: datetime + + additional_info: dict = field(default_factory=dict) @property def id(self): @@ -32,6 +34,12 @@ class CacheAttribute: return self.__dict__ == other.__dict__ +@dataclass +class CacheResult: + content: bytes + attribute: CacheAttribute + + class Cache: def __init__(self, module: str, logger: logging.Logger): self.module = module @@ -100,7 +108,7 @@ class Cache: return True - def set(self, content: bytes, name: str, expires_in: float = 10, module: str = ""): + def set(self, content: bytes, name: str, expires_in: float = 10, module: str = "", additional_info: dict = None): """ :param content: :param module: @@ -111,6 +119,7 @@ class Cache: if name == "": return + additional_info = additional_info or {} module = self.module if module == "" else module module_path = self._init_module(module) @@ -128,7 +137,7 @@ class Cache: self.logger.debug(f"writing cache to {cache_path}") content_file.write(content) - def get(self, name: str) -> Optional[bytes]: + def get(self, name: str) -> Optional[CacheResult]: path = fit_to_file_system(Path(self._dir, self.module, name), hidden_ok=True) if not path.is_file(): @@ -140,7 +149,7 @@ class Cache: return with path.open("rb") as f: - return f.read() + return CacheResult(content=f.read(), attribute=existing_attribute) def clean(self): keep = set() diff --git a/music_kraken/connection/connection.py b/music_kraken/connection/connection.py index f648fa1..de18879 100644 --- a/music_kraken/connection/connection.py +++ b/music_kraken/connection/connection.py @@ -133,7 +133,9 @@ class Connection: if self.cache.get(name) is not None and no_update_if_valid_exists: return - self.cache.set(r.content, name, expires_in=kwargs.get("expires_in", self.cache_expiring_duration), **n_kwargs) + self.cache.set(r.content, name, expires_in=kwargs.get("expires_in", self.cache_expiring_duration), additional_info={ + "encoding", r.encoding, + }, **n_kwargs) def request( self, @@ -189,10 +191,14 @@ class Connection: request_trace(f"{trace_string}\t[cached]") with responses.RequestsMock() as resp: + body = cached.content + if "encoding" in cached.additional_info: + body = body.decode(cached.additional_info["encoding"]) + resp.add( method=method, url=request_url, - body=cached, + body=body, ) return requests.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs)