feat: implemented caching in the request method

This commit is contained in:
Hazel 2024-01-17 15:10:50 +01:00
parent 66f4ad3df5
commit 031f274d69
5 changed files with 77 additions and 42 deletions

View File

@ -4,8 +4,9 @@ from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Optional from typing import List, Optional
from functools import lru_cache from functools import lru_cache
import logging
from .config import main_settings from ..utils.config import main_settings
@dataclass @dataclass
@ -29,7 +30,10 @@ class CacheAttribute:
class Cache: class Cache:
def __init__(self): def __init__(self, module: str, logger: logging.Logger):
self.module = module
self.logger: logging.Logger = logger
self._dir = main_settings["cache_directory"] self._dir = main_settings["cache_directory"]
self.index = Path(self._dir, "index.json") self.index = Path(self._dir, "index.json")
@ -89,7 +93,7 @@ class Cache:
return True return True
def set(self, content: bytes, name: str, expires_in: float = 10):
    """
    Write ``content`` to a file called ``name`` inside the path returned by
    ``self._init_module(self.module)`` and record a matching ``CacheAttribute``
    through ``self._write_attribute``.

    Nothing is written when ``name`` is an empty string.

    :param content: the raw bytes to store
    :param name: file name of the cache entry
    :param expires_in: the unit is days
    :return:
    """
    if name == "":
        return

    module_path = self._init_module(self.module)

    # take the timestamp once, so that ``expires`` is exactly ``created + expires_in``
    now = datetime.now()
    cache_attribute = CacheAttribute(
        module=self.module,
        name=name,
        created=now,
        expires=now + timedelta(days=expires_in),
    )
    self._write_attribute(cache_attribute)

    cache_path = Path(module_path, name)
    with cache_path.open("wb") as content_file:
        self.logger.debug(f"writing cache to {cache_path}")
        content_file.write(content)
def get(self, module: str, name: str) -> Optional[bytes]: def get(self, name: str) -> Optional[bytes]:
path = Path(self._dir, module, name) path = Path(self._dir, self.module, name)
if not path.is_file(): if not path.is_file():
return None return None
# check if it is outdated # check if it is outdated
existing_attribute: CacheAttribute = self._id_to_attribute[f"{module}_{name}"] existing_attribute: CacheAttribute = self._id_to_attribute[f"{self.module}_{name}"]
if not existing_attribute.is_valid: if not existing_attribute.is_valid:
return return

View File

@ -5,9 +5,12 @@ import logging
import threading import threading
import requests import requests
import responses
from responses import matchers
from tqdm import tqdm from tqdm import tqdm
from .rotating import RotatingProxy from .rotating import RotatingProxy
from .cache import Cache
from ..utils.config import main_settings from ..utils.config import main_settings
from ..utils.support_classes.download_result import DownloadResult from ..utils.support_classes.download_result import DownloadResult
from ..objects import Target from ..objects import Target
@ -25,13 +28,18 @@ class Connection:
accepted_response_codes: Set[int] = None, accepted_response_codes: Set[int] = None,
semantic_not_found: bool = True, semantic_not_found: bool = True,
sleep_after_404: float = 0.0, sleep_after_404: float = 0.0,
heartbeat_interval = 0, heartbeat_interval=0,
module: str = "general",
cache_expiring_duration: float = 10
): ):
if proxies is None: if proxies is None:
proxies = main_settings["proxies"] proxies = main_settings["proxies"]
if header_values is None: if header_values is None:
header_values = dict() header_values = dict()
self.cache: Cache = Cache(module=module, logger=logger)
self.cache_expiring_duration = cache_expiring_duration
self.HEADER_VALUES = header_values self.HEADER_VALUES = header_values
self.LOGGER = logger self.LOGGER = logger
@ -55,23 +63,24 @@ class Connection:
@property
def user_agent(self) -> str:
    """Return the session's ``user-agent`` header, or a fixed desktop Chrome string if it is not set."""
    fallback_agent = (
        "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36"
    )
    return self.session.headers.get("user-agent", fallback_agent)
def start_heartbeat(self):
    """
    Start a daemon thread that runs ``self._heartbeat_loop`` with ``self.heartbeat_interval``.

    With a non-positive interval only a warning is logged and no thread is started.
    """
    if self.heartbeat_interval <= 0:
        self.LOGGER.warning(f"Can't start a heartbeat with {self.heartbeat_interval}s in between.")
        # the warning says the heartbeat can't be started, so don't spawn the thread anyway
        return

    self.heartbeat_thread = threading.Thread(
        target=self._heartbeat_loop,
        args=(self.heartbeat_interval,),
        daemon=True,
    )
    self.heartbeat_thread.start()
def heartbeat_failed(self):
    """Log a warning that the heartbeat failed."""
    failure_notice = "I just died... (The heartbeat failed)"
    self.LOGGER.warning(failure_notice)
def heartbeat(self):
    """Placeholder heartbeat: only prints a reminder that the real implementation is missing."""
    # Your code to send heartbeat requests goes here
    reminder = (
        "the hearth is beating, but it needs to be implemented ;-;\n"
        "Fuck youuuu for setting heartbeat in the constructor to true, "
        "but not implementing the method Connection.hearbeat()"
    )
    print(reminder)
def _heartbeat_loop(self, interval: float): def _heartbeat_loop(self, interval: float):
def heartbeat_wrapper(): def heartbeat_wrapper():
@ -85,8 +94,6 @@ class Connection:
heartbeat_wrapper() heartbeat_wrapper()
time.sleep(interval) time.sleep(interval)
def base_url(self, url: ParseResult = None): def base_url(self, url: ParseResult = None):
if url is None: if url is None:
url = self.HOST url = self.HOST
@ -119,9 +126,12 @@ class Connection:
return headers return headers
def save(self, r: requests.Response, name: str, **kwargs):
    """
    Hand the body of ``r`` to ``self.cache.set`` under ``name``.

    An ``expires_in`` keyword argument overrides ``self.cache_expiring_duration``.
    """
    expires_in = kwargs.get("expires_in", self.cache_expiring_duration)
    self.cache.set(r.content, name, expires_in=expires_in)
def request(
self, self,
request: Callable, method: str,
try_count: int, try_count: int,
accepted_response_codes: set, accepted_response_codes: set,
url: str, url: str,
@ -131,8 +141,20 @@ class Connection:
raw_url: bool = False, raw_url: bool = False,
sleep_after_404: float = None, sleep_after_404: float = None,
is_heartbeat: bool = False, is_heartbeat: bool = False,
name: str = "",
**kwargs **kwargs
) -> Optional[requests.Response]: ) -> Optional[requests.Response]:
if name != "":
cached = self.cache.get(name)
with responses.RequestsMock() as resp:
resp.add(
method=method,
url=url,
body=cached,
)
return requests.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs)
if sleep_after_404 is None: if sleep_after_404 is None:
sleep_after_404 = self.sleep_after_404 sleep_after_404 = self.sleep_after_404
if try_count >= self.TRIES: if try_count >= self.TRIES:
@ -158,9 +180,10 @@ class Connection:
while self.session_is_occupied and not is_heartbeat: while self.session_is_occupied and not is_heartbeat:
pass pass
r: requests.Response = request(request_url, timeout=timeout, headers=headers, **kwargs) r: requests.Response = requests.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs)
if r.status_code in accepted_response_codes: if r.status_code in accepted_response_codes:
self.save(r, name, **kwargs)
return r return r
if self.SEMANTIC_NOT_FOUND and r.status_code == 404: if self.SEMANTIC_NOT_FOUND and r.status_code == 404:
@ -187,15 +210,16 @@ class Connection:
if self.heartbeat_interval > 0 and self.heartbeat_thread is None: if self.heartbeat_interval > 0 and self.heartbeat_thread is None:
self.start_heartbeat() self.start_heartbeat()
return self._request( return self.request(
request=request, method=method,
try_count=try_count+1, try_count=try_count + 1,
accepted_response_codes=accepted_response_codes, accepted_response_codes=accepted_response_codes,
url=url, url=url,
timeout=timeout, timeout=timeout,
headers=headers, headers=headers,
sleep_after_404=sleep_after_404, sleep_after_404=sleep_after_404,
is_heartbeat=is_heartbeat, is_heartbeat=is_heartbeat,
name=name,
**kwargs **kwargs
) )
@ -213,8 +237,8 @@ class Connection:
if accepted_response_codes is None: if accepted_response_codes is None:
accepted_response_codes = self.ACCEPTED_RESPONSE_CODES accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
r = self._request( r = self.request(
request=self.session.get, method="GET",
try_count=0, try_count=0,
accepted_response_codes=accepted_response_codes, accepted_response_codes=accepted_response_codes,
url=url, url=url,
@ -241,8 +265,8 @@ class Connection:
raw_url: bool = False, raw_url: bool = False,
**kwargs **kwargs
) -> Optional[requests.Response]: ) -> Optional[requests.Response]:
r = self._request( r = self.request(
request=self.session.post, method="POST",
try_count=0, try_count=0,
accepted_response_codes=accepted_response_codes or self.ACCEPTED_RESPONSE_CODES, accepted_response_codes=accepted_response_codes or self.ACCEPTED_RESPONSE_CODES,
url=url, url=url,
@ -282,9 +306,9 @@ class Connection:
if accepted_response_codes is None: if accepted_response_codes is None:
accepted_response_codes = self.ACCEPTED_RESPONSE_CODES accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
r = self._request( r = self.request(
request=self.session.get, method="GET",
try_count=0, try_count=0,
accepted_response_codes=accepted_response_codes, accepted_response_codes=accepted_response_codes,
url=url, url=url,
@ -310,8 +334,9 @@ class Connection:
https://en.wikipedia.org/wiki/Kilobyte https://en.wikipedia.org/wiki/Kilobyte
> The internationally recommended unit symbol for the kilobyte is kB. > The internationally recommended unit symbol for the kilobyte is kB.
""" """
with tqdm(total=total_size-target.size, unit='B', unit_scale=True, unit_divisor=1024, desc=description) as t: with tqdm(total=total_size - target.size, unit='B', unit_scale=True, unit_divisor=1024,
desc=description) as t:
try: try:
for chunk in r.iter_content(chunk_size=chunk_size): for chunk in r.iter_content(chunk_size=chunk_size):
size = f.write(chunk) size = f.write(chunk)
@ -321,7 +346,8 @@ class Connection:
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
if try_count >= self.TRIES: if try_count >= self.TRIES:
self.LOGGER.warning(f"Stream timed out at \"{url}\": to many retries, aborting.") self.LOGGER.warning(f"Stream timed out at \"{url}\": to many retries, aborting.")
return DownloadResult(error_message=f"Stream timed out from {url}, reducing the chunksize might help.") return DownloadResult(
error_message=f"Stream timed out from {url}, reducing the chunksize might help.")
self.LOGGER.warning(f"Stream timed out at \"{url}\": ({try_count}-{self.TRIES})") self.LOGGER.warning(f"Stream timed out at \"{url}\": ({try_count}-{self.TRIES})")
retry = True retry = True
@ -329,15 +355,14 @@ class Connection:
if total_size > progress: if total_size > progress:
retry = True retry = True
if retry: if retry:
self.LOGGER.warning(f"Retrying stream...") self.LOGGER.warning(f"Retrying stream...")
accepted_response_codes.add(206) accepted_response_codes.add(206)
return self.stream_into( return self.stream_into(
url = url, url=url,
target = target, target=target,
description = description, description=description,
try_count=try_count+1, try_count=try_count + 1,
progress=progress, progress=progress,
accepted_response_codes=accepted_response_codes, accepted_response_codes=accepted_response_codes,
timeout=timeout, timeout=timeout,

View File

@ -356,6 +356,5 @@ class Bandcamp(Page):
def download_song_to_target(self, source: Source, target: Target, desc: str = None) -> DownloadResult:
    """
    Stream the audio behind ``source.audio_url`` into ``target`` via ``self.connection.stream_into``.

    Returns a ``DownloadResult`` carrying an error message when the source has no audio url.
    """
    audio_url = source.audio_url
    if audio_url is not None:
        return self.connection.stream_into(url=audio_url, target=target, description=desc)

    return DownloadResult(error_message="Couldn't find download link.")

View File

@ -59,7 +59,8 @@ class YoutubeMusicConnection(Connection):
heartbeat_interval=113.25, heartbeat_interval=113.25,
header_values={ header_values={
"Accept-Language": accept_language "Accept-Language": accept_language
} },
module="youtube_music",
) )
# cookie consent for youtube # cookie consent for youtube
@ -161,8 +162,10 @@ class YoutubeMusic(SuperYouTube):
# save cookies in settings # save cookies in settings
youtube_settings["youtube_music_consent_cookies"] = cookie_dict youtube_settings["youtube_music_consent_cookies"] = cookie_dict
else:
self.connection.save(r, "index.html")
r = self.connection.get("https://music.youtube.com/") r = self.connection.get("https://music.youtube.com/", name="index.html")
if r is None: if r is None:
return return

View File

@ -4,7 +4,7 @@ from .config import main_settings
DEBUG = True DEBUG = True
DEBUG_LOGGING = DEBUG and True DEBUG_LOGGING = DEBUG and True
DEBUG_YOUTUBE_INITIALIZING = DEBUG and False DEBUG_YOUTUBE_INITIALIZING = DEBUG and True
DEBUG_PAGES = DEBUG and False DEBUG_PAGES = DEBUG and False
if DEBUG: if DEBUG: