diff --git a/src/music_kraken/connection/connection.py b/src/music_kraken/connection/connection.py index 29ed3d3..2c25d05 100644 --- a/src/music_kraken/connection/connection.py +++ b/src/music_kraken/connection/connection.py @@ -4,10 +4,12 @@ import time from typing import List, Dict, Optional, Set from urllib.parse import urlparse, urlunsplit, ParseResult import copy +import inspect import requests import responses from tqdm import tqdm +import merge_args from .cache import Cache from .rotating import RotatingProxy @@ -125,32 +127,35 @@ class Connection: def request( self, - method: str, - try_count: int, - accepted_response_codes: set, url: str, - timeout: float, - headers: Optional[dict], + timeout: float = None, + headers: Optional[dict] = None, + try_count: int = 0, + accepted_response_codes: set = None, refer_from_origin: bool = True, raw_url: bool = False, sleep_after_404: float = None, is_heartbeat: bool = False, + disable_cache: bool = None, + method: str = None, name: str = "", **kwargs ) -> Optional[requests.Response]: - current_kwargs = copy.copy(locals) + if method is None: + raise AttributeError("method is not set.") + method = method.upper() + disable_cache = headers.get("Cache-Control", "").lower() == "no-cache" if disable_cache is None else disable_cache + accepted_response_codes = self.ACCEPTED_RESPONSE_CODES if accepted_response_codes is None else accepted_response_codes + + current_kwargs = copy.copy(locals()) parsed_url = urlparse(url) - headers = self._update_headers( headers=headers, refer_from_origin=refer_from_origin, url=parsed_url ) - disable_cache = headers.get("Cache-Control") == "no-cache" or kwargs.get("disable_cache", False) - - if name != "" and not disable_cache: cached = self.cache.get(name) @@ -225,105 +230,58 @@ class Connection: current_kwargs["try_count"] = current_kwargs.get("try_count", 0) + 1 return self.request(**current_kwargs) - def get( - self, - url: str, - refer_from_origin: bool = True, - stream: bool = False, - accepted_response_codes: set = None, - timeout: float = None, - headers: dict = None, - raw_url: bool = False, - **kwargs - ) -> Optional[requests.Response]: - if accepted_response_codes is None: - accepted_response_codes = self.ACCEPTED_RESPONSE_CODES - - r = self.request( + @merge_args(request) + def get(self, *args, **kwargs) -> Optional[requests.Response]: + return self.request( + *args, method="GET", - try_count=0, - accepted_response_codes=accepted_response_codes, - url=url, - timeout=timeout, - headers=headers, - raw_url=raw_url, - refer_from_origin=refer_from_origin, - stream=stream, **kwargs ) - if r is None: - self.LOGGER.warning(f"Max attempts ({self.TRIES}) exceeded for: GET:{url}") - return r + @merge_args(request) def post( self, - url: str, + *args, json: dict = None, - refer_from_origin: bool = True, - stream: bool = False, - accepted_response_codes: set = None, - timeout: float = None, - headers: dict = None, - raw_url: bool = False, **kwargs ) -> Optional[requests.Response]: r = self.request( + *args, method="POST", - try_count=0, - accepted_response_codes=accepted_response_codes or self.ACCEPTED_RESPONSE_CODES, - url=url, - timeout=timeout, - headers=headers, - refer_from_origin=refer_from_origin, - raw_url=raw_url, json=json, - stream=stream, **kwargs ) if r is None: - self.LOGGER.warning(f"Max attempts ({self.TRIES}) exceeded for: GET:{url}") self.LOGGER.warning(f"payload: {json}") return r + @merge_args(request) def stream_into( self, url: str, target: Target, - description: str = "download", - refer_from_origin: bool = True, - accepted_response_codes: set = None, - timeout: float = None, - headers: dict = None, - raw_url: bool = False, + name: str = "download", chunk_size: int = main_settings["chunk_size"], - try_count: int = 0, progress: int = 0, + method: str = "GET", **kwargs ) -> DownloadResult: + stream_kwargs = copy.copy(locals()) if progress > 0: - if headers is None: - headers = dict() + headers = dict() if headers is None else headers headers["Range"] = f"bytes={target.size}-" - if accepted_response_codes is None: - accepted_response_codes = self.ACCEPTED_RESPONSE_CODES - r = self.request( - method="GET", - try_count=0, - accepted_response_codes=accepted_response_codes, url=url, - timeout=timeout, - headers=headers, - raw_url=raw_url, - refer_from_origin=refer_from_origin, - stream=True, + name=name, + chunk_size=chunk_size, + method=method, **kwargs ) if r is None: - return DownloadResult(error_message=f"Could not establish connection to: {url}") + return DownloadResult(error_message=f"Could not establish a stream from: {url}") target.create_path() total_size = int(r.headers.get('content-length')) @@ -337,8 +295,7 @@ class Connection: > The internationally recommended unit symbol for the kilobyte is kB. """ - with tqdm(total=total_size - target.size, unit='B', unit_scale=True, unit_divisor=1024, - desc=description) as t: + with tqdm(total=total_size - target.size, unit='B', unit_scale=True, unit_divisor=1024, desc=name) as t: try: for chunk in r.iter_content(chunk_size=chunk_size): size = f.write(chunk) @@ -348,8 +305,7 @@ class Connection: except requests.exceptions.ConnectionError: if try_count >= self.TRIES: self.LOGGER.warning(f"Stream timed out at \"{url}\": to many retries, aborting.") - return DownloadResult( - error_message=f"Stream timed out from {url}, reducing the chunksize might help.") + return DownloadResult(error_message=f"Stream timed out from {url}, reducing the chunk_size might help.") self.LOGGER.warning(f"Stream timed out at \"{url}\": ({try_count}-{self.TRIES})") retry = True @@ -360,19 +316,6 @@ class Connection: if retry: self.LOGGER.warning(f"Retrying stream...") accepted_response_codes.add(206) - return self.stream_into( - url=url, - target=target, - description=description, - try_count=try_count + 1, - progress=progress, - accepted_response_codes=accepted_response_codes, - timeout=timeout, - headers=headers, - raw_url=raw_url, - refer_from_origin=refer_from_origin, - chunk_size=chunk_size, - **kwargs - ) + return self.stream_into(**stream_kwargs) return DownloadResult()