fix: encoding of cache

This commit is contained in:
Hazel 2024-04-26 14:04:44 +02:00
parent e77afa584b
commit 25eceb727b
2 changed files with 21 additions and 6 deletions

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from dataclasses import dataclass from dataclasses import dataclass, field
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Optional from typing import List, Optional
from functools import lru_cache from functools import lru_cache
@ -18,6 +18,8 @@ class CacheAttribute:
created: datetime created: datetime
expires: datetime expires: datetime
additional_info: dict = field(default_factory=dict)
@property @property
def id(self): def id(self):
return f"{self.module}_{self.name}" return f"{self.module}_{self.name}"
@ -32,6 +34,12 @@ class CacheAttribute:
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
@dataclass
class CacheResult:
content: bytes
attribute: CacheAttribute
class Cache: class Cache:
def __init__(self, module: str, logger: logging.Logger): def __init__(self, module: str, logger: logging.Logger):
self.module = module self.module = module
@ -100,7 +108,7 @@ class Cache:
return True return True
def set(self, content: bytes, name: str, expires_in: float = 10, module: str = ""): def set(self, content: bytes, name: str, expires_in: float = 10, module: str = "", additional_info: dict = None):
""" """
:param content: :param content:
:param module: :param module:
@ -111,6 +119,7 @@ class Cache:
if name == "": if name == "":
return return
additional_info = additional_info or {}
module = self.module if module == "" else module module = self.module if module == "" else module
module_path = self._init_module(module) module_path = self._init_module(module)
@ -128,7 +137,7 @@ class Cache:
self.logger.debug(f"writing cache to {cache_path}") self.logger.debug(f"writing cache to {cache_path}")
content_file.write(content) content_file.write(content)
def get(self, name: str) -> Optional[bytes]: def get(self, name: str) -> Optional[CacheResult]:
path = fit_to_file_system(Path(self._dir, self.module, name), hidden_ok=True) path = fit_to_file_system(Path(self._dir, self.module, name), hidden_ok=True)
if not path.is_file(): if not path.is_file():
@ -140,7 +149,7 @@ class Cache:
return return
with path.open("rb") as f: with path.open("rb") as f:
return f.read() return CacheResult(content=f.read(), attribute=existing_attribute)
def clean(self): def clean(self):
keep = set() keep = set()

View File

@ -133,7 +133,9 @@ class Connection:
if self.cache.get(name) is not None and no_update_if_valid_exists: if self.cache.get(name) is not None and no_update_if_valid_exists:
return return
self.cache.set(r.content, name, expires_in=kwargs.get("expires_in", self.cache_expiring_duration), **n_kwargs) self.cache.set(r.content, name, expires_in=kwargs.get("expires_in", self.cache_expiring_duration), additional_info={
"encoding", r.encoding,
}, **n_kwargs)
def request( def request(
self, self,
@ -189,10 +191,14 @@ class Connection:
request_trace(f"{trace_string}\t[cached]") request_trace(f"{trace_string}\t[cached]")
with responses.RequestsMock() as resp: with responses.RequestsMock() as resp:
body = cached.content
if "encoding" in cached.additional_info:
body = body.decode(cached.additional_info["encoding"])
resp.add( resp.add(
method=method, method=method,
url=request_url, url=request_url,
body=cached, body=body,
) )
return requests.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs) return requests.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs)