feat: implemented caching in the request method
This commit is contained in:
parent
66f4ad3df5
commit
031f274d69
@ -4,8 +4,9 @@ from dataclasses import dataclass
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
import logging
|
||||||
|
|
||||||
from .config import main_settings
|
from ..utils.config import main_settings
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -29,7 +30,10 @@ class CacheAttribute:
|
|||||||
|
|
||||||
|
|
||||||
class Cache:
|
class Cache:
|
||||||
def __init__(self):
|
def __init__(self, module: str, logger: logging.Logger):
|
||||||
|
self.module = module
|
||||||
|
self.logger: logging.Logger = logger
|
||||||
|
|
||||||
self._dir = main_settings["cache_directory"]
|
self._dir = main_settings["cache_directory"]
|
||||||
self.index = Path(self._dir, "index.json")
|
self.index = Path(self._dir, "index.json")
|
||||||
|
|
||||||
@ -89,7 +93,7 @@ class Cache:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def set(self, content: bytes, module: str, name: str, expires_in: int = 10):
|
def set(self, content: bytes, name: str, expires_in: float = 10):
|
||||||
"""
|
"""
|
||||||
:param content:
|
:param content:
|
||||||
:param module:
|
:param module:
|
||||||
@ -97,28 +101,32 @@ class Cache:
|
|||||||
:param expires_in: the unit is days
|
:param expires_in: the unit is days
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
if name == "":
|
||||||
|
return
|
||||||
|
|
||||||
module_path = self._init_module(module)
|
module_path = self._init_module(self.module)
|
||||||
|
|
||||||
cache_attribute = CacheAttribute(
|
cache_attribute = CacheAttribute(
|
||||||
module=module,
|
module=self.module,
|
||||||
name=name,
|
name=name,
|
||||||
created=datetime.now(),
|
created=datetime.now(),
|
||||||
expires=datetime.now() + timedelta(days=expires_in),
|
expires=datetime.now() + timedelta(days=expires_in),
|
||||||
)
|
)
|
||||||
self._write_attribute(cache_attribute)
|
self._write_attribute(cache_attribute)
|
||||||
|
|
||||||
with Path(module_path, name).open("wb") as content_file:
|
cache_path = Path(module_path, name)
|
||||||
|
with cache_path.open("wb") as content_file:
|
||||||
|
self.logger.debug(f"writing cache to {cache_path}")
|
||||||
content_file.write(content)
|
content_file.write(content)
|
||||||
|
|
||||||
def get(self, module: str, name: str) -> Optional[bytes]:
|
def get(self, name: str) -> Optional[bytes]:
|
||||||
path = Path(self._dir, module, name)
|
path = Path(self._dir, self.module, name)
|
||||||
|
|
||||||
if not path.is_file():
|
if not path.is_file():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# check if it is outdated
|
# check if it is outdated
|
||||||
existing_attribute: CacheAttribute = self._id_to_attribute[f"{module}_{name}"]
|
existing_attribute: CacheAttribute = self._id_to_attribute[f"{self.module}_{name}"]
|
||||||
if not existing_attribute.is_valid:
|
if not existing_attribute.is_valid:
|
||||||
return
|
return
|
||||||
|
|
@ -5,9 +5,12 @@ import logging
|
|||||||
|
|
||||||
import threading
|
import threading
|
||||||
import requests
|
import requests
|
||||||
|
import responses
|
||||||
|
from responses import matchers
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .rotating import RotatingProxy
|
from .rotating import RotatingProxy
|
||||||
|
from .cache import Cache
|
||||||
from ..utils.config import main_settings
|
from ..utils.config import main_settings
|
||||||
from ..utils.support_classes.download_result import DownloadResult
|
from ..utils.support_classes.download_result import DownloadResult
|
||||||
from ..objects import Target
|
from ..objects import Target
|
||||||
@ -26,12 +29,17 @@ class Connection:
|
|||||||
semantic_not_found: bool = True,
|
semantic_not_found: bool = True,
|
||||||
sleep_after_404: float = 0.0,
|
sleep_after_404: float = 0.0,
|
||||||
heartbeat_interval=0,
|
heartbeat_interval=0,
|
||||||
|
module: str = "general",
|
||||||
|
cache_expiring_duration: float = 10
|
||||||
):
|
):
|
||||||
if proxies is None:
|
if proxies is None:
|
||||||
proxies = main_settings["proxies"]
|
proxies = main_settings["proxies"]
|
||||||
if header_values is None:
|
if header_values is None:
|
||||||
header_values = dict()
|
header_values = dict()
|
||||||
|
|
||||||
|
self.cache: Cache = Cache(module=module, logger=logger)
|
||||||
|
self.cache_expiring_duration = cache_expiring_duration
|
||||||
|
|
||||||
self.HEADER_VALUES = header_values
|
self.HEADER_VALUES = header_values
|
||||||
|
|
||||||
self.LOGGER = logger
|
self.LOGGER = logger
|
||||||
@ -55,23 +63,24 @@ class Connection:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def user_agent(self) -> str:
|
def user_agent(self) -> str:
|
||||||
return self.session.headers.get("user-agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36")
|
return self.session.headers.get("user-agent",
|
||||||
|
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36")
|
||||||
|
|
||||||
def start_heartbeat(self):
|
def start_heartbeat(self):
|
||||||
if self.heartbeat_interval <= 0:
|
if self.heartbeat_interval <= 0:
|
||||||
self.LOGGER.warning(f"Can't start a heartbeat with {self.heartbeat_interval}s in between.")
|
self.LOGGER.warning(f"Can't start a heartbeat with {self.heartbeat_interval}s in between.")
|
||||||
|
|
||||||
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, args=(self.heartbeat_interval, ), daemon=True)
|
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, args=(self.heartbeat_interval,),
|
||||||
|
daemon=True)
|
||||||
self.heartbeat_thread.start()
|
self.heartbeat_thread.start()
|
||||||
|
|
||||||
def heartbeat_failed(self):
|
def heartbeat_failed(self):
|
||||||
self.LOGGER.warning(f"I just died... (The heartbeat failed)")
|
self.LOGGER.warning(f"I just died... (The heartbeat failed)")
|
||||||
|
|
||||||
|
|
||||||
def heartbeat(self):
|
def heartbeat(self):
|
||||||
# Your code to send heartbeat requests goes here
|
# Your code to send heartbeat requests goes here
|
||||||
print("the hearth is beating, but it needs to be implemented ;-;\nFuck youuuu for setting heartbeat in the constructor to true, but not implementing the method Connection.hearbeat()")
|
print(
|
||||||
|
"the hearth is beating, but it needs to be implemented ;-;\nFuck youuuu for setting heartbeat in the constructor to true, but not implementing the method Connection.hearbeat()")
|
||||||
|
|
||||||
def _heartbeat_loop(self, interval: float):
|
def _heartbeat_loop(self, interval: float):
|
||||||
def heartbeat_wrapper():
|
def heartbeat_wrapper():
|
||||||
@ -85,8 +94,6 @@ class Connection:
|
|||||||
heartbeat_wrapper()
|
heartbeat_wrapper()
|
||||||
time.sleep(interval)
|
time.sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def base_url(self, url: ParseResult = None):
|
def base_url(self, url: ParseResult = None):
|
||||||
if url is None:
|
if url is None:
|
||||||
url = self.HOST
|
url = self.HOST
|
||||||
@ -119,9 +126,12 @@ class Connection:
|
|||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def _request(
|
def save(self, r: requests.Response, name: str, **kwargs):
|
||||||
|
self.cache.set(r.content, name, expires_in=kwargs.get("expires_in", self.cache_expiring_duration))
|
||||||
|
|
||||||
|
def request(
|
||||||
self,
|
self,
|
||||||
request: Callable,
|
method: str,
|
||||||
try_count: int,
|
try_count: int,
|
||||||
accepted_response_codes: set,
|
accepted_response_codes: set,
|
||||||
url: str,
|
url: str,
|
||||||
@ -131,8 +141,20 @@ class Connection:
|
|||||||
raw_url: bool = False,
|
raw_url: bool = False,
|
||||||
sleep_after_404: float = None,
|
sleep_after_404: float = None,
|
||||||
is_heartbeat: bool = False,
|
is_heartbeat: bool = False,
|
||||||
|
name: str = "",
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Optional[requests.Response]:
|
) -> Optional[requests.Response]:
|
||||||
|
if name != "":
|
||||||
|
cached = self.cache.get(name)
|
||||||
|
|
||||||
|
with responses.RequestsMock() as resp:
|
||||||
|
resp.add(
|
||||||
|
method=method,
|
||||||
|
url=url,
|
||||||
|
body=cached,
|
||||||
|
)
|
||||||
|
return requests.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs)
|
||||||
|
|
||||||
if sleep_after_404 is None:
|
if sleep_after_404 is None:
|
||||||
sleep_after_404 = self.sleep_after_404
|
sleep_after_404 = self.sleep_after_404
|
||||||
if try_count >= self.TRIES:
|
if try_count >= self.TRIES:
|
||||||
@ -158,9 +180,10 @@ class Connection:
|
|||||||
while self.session_is_occupied and not is_heartbeat:
|
while self.session_is_occupied and not is_heartbeat:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
r: requests.Response = request(request_url, timeout=timeout, headers=headers, **kwargs)
|
r: requests.Response = requests.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs)
|
||||||
|
|
||||||
if r.status_code in accepted_response_codes:
|
if r.status_code in accepted_response_codes:
|
||||||
|
self.save(r, name, **kwargs)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
if self.SEMANTIC_NOT_FOUND and r.status_code == 404:
|
if self.SEMANTIC_NOT_FOUND and r.status_code == 404:
|
||||||
@ -187,8 +210,8 @@ class Connection:
|
|||||||
if self.heartbeat_interval > 0 and self.heartbeat_thread is None:
|
if self.heartbeat_interval > 0 and self.heartbeat_thread is None:
|
||||||
self.start_heartbeat()
|
self.start_heartbeat()
|
||||||
|
|
||||||
return self._request(
|
return self.request(
|
||||||
request=request,
|
method=method,
|
||||||
try_count=try_count + 1,
|
try_count=try_count + 1,
|
||||||
accepted_response_codes=accepted_response_codes,
|
accepted_response_codes=accepted_response_codes,
|
||||||
url=url,
|
url=url,
|
||||||
@ -196,6 +219,7 @@ class Connection:
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
sleep_after_404=sleep_after_404,
|
sleep_after_404=sleep_after_404,
|
||||||
is_heartbeat=is_heartbeat,
|
is_heartbeat=is_heartbeat,
|
||||||
|
name=name,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -213,8 +237,8 @@ class Connection:
|
|||||||
if accepted_response_codes is None:
|
if accepted_response_codes is None:
|
||||||
accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
|
accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
|
||||||
|
|
||||||
r = self._request(
|
r = self.request(
|
||||||
request=self.session.get,
|
method="GET",
|
||||||
try_count=0,
|
try_count=0,
|
||||||
accepted_response_codes=accepted_response_codes,
|
accepted_response_codes=accepted_response_codes,
|
||||||
url=url,
|
url=url,
|
||||||
@ -241,8 +265,8 @@ class Connection:
|
|||||||
raw_url: bool = False,
|
raw_url: bool = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Optional[requests.Response]:
|
) -> Optional[requests.Response]:
|
||||||
r = self._request(
|
r = self.request(
|
||||||
request=self.session.post,
|
method="POST",
|
||||||
try_count=0,
|
try_count=0,
|
||||||
accepted_response_codes=accepted_response_codes or self.ACCEPTED_RESPONSE_CODES,
|
accepted_response_codes=accepted_response_codes or self.ACCEPTED_RESPONSE_CODES,
|
||||||
url=url,
|
url=url,
|
||||||
@ -283,8 +307,8 @@ class Connection:
|
|||||||
if accepted_response_codes is None:
|
if accepted_response_codes is None:
|
||||||
accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
|
accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
|
||||||
|
|
||||||
r = self._request(
|
r = self.request(
|
||||||
request=self.session.get,
|
method="GET",
|
||||||
try_count=0,
|
try_count=0,
|
||||||
accepted_response_codes=accepted_response_codes,
|
accepted_response_codes=accepted_response_codes,
|
||||||
url=url,
|
url=url,
|
||||||
@ -311,7 +335,8 @@ class Connection:
|
|||||||
> The internationally recommended unit symbol for the kilobyte is kB.
|
> The internationally recommended unit symbol for the kilobyte is kB.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with tqdm(total=total_size-target.size, unit='B', unit_scale=True, unit_divisor=1024, desc=description) as t:
|
with tqdm(total=total_size - target.size, unit='B', unit_scale=True, unit_divisor=1024,
|
||||||
|
desc=description) as t:
|
||||||
try:
|
try:
|
||||||
for chunk in r.iter_content(chunk_size=chunk_size):
|
for chunk in r.iter_content(chunk_size=chunk_size):
|
||||||
size = f.write(chunk)
|
size = f.write(chunk)
|
||||||
@ -321,7 +346,8 @@ class Connection:
|
|||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
if try_count >= self.TRIES:
|
if try_count >= self.TRIES:
|
||||||
self.LOGGER.warning(f"Stream timed out at \"{url}\": to many retries, aborting.")
|
self.LOGGER.warning(f"Stream timed out at \"{url}\": to many retries, aborting.")
|
||||||
return DownloadResult(error_message=f"Stream timed out from {url}, reducing the chunksize might help.")
|
return DownloadResult(
|
||||||
|
error_message=f"Stream timed out from {url}, reducing the chunksize might help.")
|
||||||
|
|
||||||
self.LOGGER.warning(f"Stream timed out at \"{url}\": ({try_count}-{self.TRIES})")
|
self.LOGGER.warning(f"Stream timed out at \"{url}\": ({try_count}-{self.TRIES})")
|
||||||
retry = True
|
retry = True
|
||||||
@ -329,7 +355,6 @@ class Connection:
|
|||||||
if total_size > progress:
|
if total_size > progress:
|
||||||
retry = True
|
retry = True
|
||||||
|
|
||||||
|
|
||||||
if retry:
|
if retry:
|
||||||
self.LOGGER.warning(f"Retrying stream...")
|
self.LOGGER.warning(f"Retrying stream...")
|
||||||
accepted_response_codes.add(206)
|
accepted_response_codes.add(206)
|
||||||
|
@ -356,6 +356,5 @@ class Bandcamp(Page):
|
|||||||
|
|
||||||
def download_song_to_target(self, source: Source, target: Target, desc: str = None) -> DownloadResult:
|
def download_song_to_target(self, source: Source, target: Target, desc: str = None) -> DownloadResult:
|
||||||
if source.audio_url is None:
|
if source.audio_url is None:
|
||||||
print(source)
|
|
||||||
return DownloadResult(error_message="Couldn't find download link.")
|
return DownloadResult(error_message="Couldn't find download link.")
|
||||||
return self.connection.stream_into(url=source.audio_url, target=target, description=desc)
|
return self.connection.stream_into(url=source.audio_url, target=target, description=desc)
|
||||||
|
@ -59,7 +59,8 @@ class YoutubeMusicConnection(Connection):
|
|||||||
heartbeat_interval=113.25,
|
heartbeat_interval=113.25,
|
||||||
header_values={
|
header_values={
|
||||||
"Accept-Language": accept_language
|
"Accept-Language": accept_language
|
||||||
}
|
},
|
||||||
|
module="youtube_music",
|
||||||
)
|
)
|
||||||
|
|
||||||
# cookie consent for youtube
|
# cookie consent for youtube
|
||||||
@ -161,8 +162,10 @@ class YoutubeMusic(SuperYouTube):
|
|||||||
|
|
||||||
# save cookies in settings
|
# save cookies in settings
|
||||||
youtube_settings["youtube_music_consent_cookies"] = cookie_dict
|
youtube_settings["youtube_music_consent_cookies"] = cookie_dict
|
||||||
|
else:
|
||||||
|
self.connection.save(r, "index.html")
|
||||||
|
|
||||||
r = self.connection.get("https://music.youtube.com/")
|
r = self.connection.get("https://music.youtube.com/", name="index.html")
|
||||||
if r is None:
|
if r is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from .config import main_settings
|
|||||||
|
|
||||||
DEBUG = True
|
DEBUG = True
|
||||||
DEBUG_LOGGING = DEBUG and True
|
DEBUG_LOGGING = DEBUG and True
|
||||||
DEBUG_YOUTUBE_INITIALIZING = DEBUG and False
|
DEBUG_YOUTUBE_INITIALIZING = DEBUG and True
|
||||||
DEBUG_PAGES = DEBUG and False
|
DEBUG_PAGES = DEBUG and False
|
||||||
|
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
|
Loading…
Reference in New Issue
Block a user