feat: added type annotation merging for connection functions

Hazel 2024-02-28 11:17:07 +01:00
parent 7f6db2781d
commit 7f2abdf572

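This change swaps the hand-copied parameter lists on get/post/stream_into for the merge_args decorator (the merge_args package on PyPI), which grafts the wrapped function's signature onto a thin *args/**kwargs wrapper so editors, help(), and type checkers still see the full parameter list. A minimal sketch of the effect, outside this repository's code (all names here are illustrative):

    import inspect
    from typing import Optional

    from merge_args import merge_args


    def request(url: str, timeout: Optional[float] = None, method: Optional[str] = None) -> str:
        return f"{method} {url} (timeout={timeout})"


    @merge_args(request)
    def get(**kwargs) -> str:
        # The body only sees **kwargs, but the decorator merges request()'s
        # parameters into get()'s advertised signature for tooling.
        return request(method="GET", **kwargs)


    print(inspect.signature(get))  # lists url/timeout/method, not just **kwargs
    print(get(url="https://example.com", timeout=5))

The point of the decorator is the surfaced signature; the wrapper's forwarding behaviour is unchanged.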

@@ -4,10 +4,12 @@ import time
 from typing import List, Dict, Optional, Set
 from urllib.parse import urlparse, urlunsplit, ParseResult
 import copy
+import inspect

 import requests
 import responses
 from tqdm import tqdm
+from merge_args import merge_args

 from .cache import Cache
 from .rotating import RotatingProxy
@@ -125,32 +127,35 @@ class Connection:
     def request(
             self,
-            method: str,
-            try_count: int,
-            accepted_response_codes: set,
             url: str,
-            timeout: float,
-            headers: Optional[dict],
+            timeout: float = None,
+            headers: Optional[dict] = None,
+            try_count: int = 0,
+            accepted_response_codes: set = None,
             refer_from_origin: bool = True,
             raw_url: bool = False,
             sleep_after_404: float = None,
             is_heartbeat: bool = False,
+            disable_cache: bool = None,
+            method: str = None,
             name: str = "",
             **kwargs
     ) -> Optional[requests.Response]:
-        current_kwargs = copy.copy(locals)
+        if method is None:
+            raise AttributeError("method is not set.")
+        method = method.upper()
+
+        disable_cache = headers.get("Cache-Control", "").lower() == "no-cache" if disable_cache is None else disable_cache
+        accepted_response_codes = self.ACCEPTED_RESPONSE_CODES if accepted_response_codes is None else accepted_response_codes
+
+        current_kwargs = copy.copy(locals())

         parsed_url = urlparse(url)

         headers = self._update_headers(
             headers=headers,
             refer_from_origin=refer_from_origin,
             url=parsed_url
         )

-        disable_cache = headers.get("Cache-Control") == "no-cache" or kwargs.get("disable_cache", False)
-
         if name != "" and not disable_cache:
             cached = self.cache.get(name)
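request() now snapshots its fully resolved arguments with copy.copy(locals()) so the retry path can replay the exact same call after bumping try_count. A standalone sketch of that pattern under assumed names (note that inside a method the snapshot also contains self and the kwargs dict, which have to be folded back in before re-calling):

    import copy


    class Client:
        TRIES = 3

        def request(self, url: str, try_count: int = 0, **kwargs):
            # Snapshot the resolved arguments before anything mutates them.
            current_kwargs = copy.copy(locals())
            current_kwargs.pop("self")                           # locals() includes self
            current_kwargs.update(current_kwargs.pop("kwargs"))  # flatten **kwargs back in

            response = self._send(url, **kwargs)  # hypothetical transport call
            if response is None and try_count < self.TRIES:
                current_kwargs["try_count"] = try_count + 1
                return self.request(**current_kwargs)
            return response

        def _send(self, url, **kwargs):
            return None  # stub: pretend every attempt fails


    print(Client().request("https://example.com"))  # retries TRIES times, then returns None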
@@ -225,105 +230,58 @@ class Connection:
             current_kwargs["try_count"] = current_kwargs.get("try_count", 0) + 1
             return self.request(**current_kwargs)

-    def get(
-            self,
-            url: str,
-            refer_from_origin: bool = True,
-            stream: bool = False,
-            accepted_response_codes: set = None,
-            timeout: float = None,
-            headers: dict = None,
-            raw_url: bool = False,
-            **kwargs
-    ) -> Optional[requests.Response]:
-        if accepted_response_codes is None:
-            accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
-
-        r = self.request(
+    @merge_args(request)
+    def get(self, *args, **kwargs) -> Optional[requests.Response]:
+        return self.request(
+            *args,
             method="GET",
-            try_count=0,
-            accepted_response_codes=accepted_response_codes,
-            url=url,
-            timeout=timeout,
-            headers=headers,
-            raw_url=raw_url,
-            refer_from_origin=refer_from_origin,
-            stream=stream,
             **kwargs
         )
-
-        if r is None:
-            self.LOGGER.warning(f"Max attempts ({self.TRIES}) exceeded for: GET:{url}")
-
-        return r

+    @merge_args(request)
     def post(
             self,
-            url: str,
+            *args,
             json: dict = None,
-            refer_from_origin: bool = True,
-            stream: bool = False,
-            accepted_response_codes: set = None,
-            timeout: float = None,
-            headers: dict = None,
-            raw_url: bool = False,
             **kwargs
     ) -> Optional[requests.Response]:
         r = self.request(
+            *args,
             method="POST",
-            try_count=0,
-            accepted_response_codes=accepted_response_codes or self.ACCEPTED_RESPONSE_CODES,
-            url=url,
-            timeout=timeout,
-            headers=headers,
-            refer_from_origin=refer_from_origin,
-            raw_url=raw_url,
             json=json,
-            stream=stream,
             **kwargs
         )
         if r is None:
-            self.LOGGER.warning(f"Max attempts ({self.TRIES}) exceeded for: GET:{url}")
             self.LOGGER.warning(f"payload: {json}")
         return r

+    @merge_args(request)
     def stream_into(
             self,
             url: str,
             target: Target,
-            description: str = "download",
-            refer_from_origin: bool = True,
-            accepted_response_codes: set = None,
-            timeout: float = None,
-            headers: dict = None,
-            raw_url: bool = False,
+            name: str = "download",
             chunk_size: int = main_settings["chunk_size"],
-            try_count: int = 0,
             progress: int = 0,
+            method: str = "GET",
             **kwargs
     ) -> DownloadResult:
+        stream_kwargs = copy.copy(locals())
+
         if progress > 0:
-            if headers is None:
-                headers = dict()
+            headers = dict() if headers is None else headers
             headers["Range"] = f"bytes={target.size}-"

-        if accepted_response_codes is None:
-            accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
-
         r = self.request(
-            method="GET",
-            try_count=0,
-            accepted_response_codes=accepted_response_codes,
             url=url,
-            timeout=timeout,
-            headers=headers,
-            raw_url=raw_url,
-            refer_from_origin=refer_from_origin,
-            stream=True,
+            name=name,
+            chunk_size=chunk_size,
+            method=method,
             **kwargs
         )

         if r is None:
-            return DownloadResult(error_message=f"Could not establish connection to: {url}")
+            return DownloadResult(error_message=f"Could not establish a stream from: {url}")

         target.create_path()
         total_size = int(r.headers.get('content-length'))
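With the defaults centralized in request(), each HTTP verb collapses into a delegate that pins method= and forwards everything else. A toy model of the resulting layout (not the repository's class; it assumes merge_args works on methods the way this commit uses it):

    from typing import Optional

    from merge_args import merge_args


    class MiniConnection:
        def request(self, url: str, method: Optional[str] = None,
                    timeout: Optional[float] = None, **kwargs) -> str:
            if method is None:
                raise AttributeError("method is not set.")
            return f"{method.upper()} {url} (timeout={timeout})"

        @merge_args(request)
        def get(self, *args, **kwargs) -> str:
            # Only the verb is pinned here; every other parameter rides through.
            return self.request(*args, method="GET", **kwargs)

        @merge_args(request)
        def post(self, *args, json: Optional[dict] = None, **kwargs) -> str:
            return self.request(*args, method="POST", json=json, **kwargs)


    print(MiniConnection().get("https://example.com", timeout=5))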
@@ -337,8 +295,7 @@ class Connection:
         > The internationally recommended unit symbol for the kilobyte is kB.
         """

-        with tqdm(total=total_size - target.size, unit='B', unit_scale=True, unit_divisor=1024,
-                  desc=description) as t:
+        with tqdm(total=total_size - target.size, unit='B', unit_scale=True, unit_divisor=1024, desc=name) as t:
             try:
                 for chunk in r.iter_content(chunk_size=chunk_size):
                     size = f.write(chunk)
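The reflowed tqdm call sets up a byte-scaled progress bar: unit='B' with unit_scale=True and unit_divisor=1024 renders the counter in binary-scaled units. A minimal standalone illustration with made-up sizes:

    from tqdm import tqdm

    total_size, already_done = 10 * 1024 * 1024, 0  # illustrative byte counts
    with tqdm(total=total_size - already_done, unit='B', unit_scale=True,
              unit_divisor=1024, desc="download") as t:
        for _ in range(10):
            t.update(1024 * 1024)  # advance by each written chunk's size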
@@ -348,8 +305,7 @@ class Connection:
             except requests.exceptions.ConnectionError:
                 if try_count >= self.TRIES:
                     self.LOGGER.warning(f"Stream timed out at \"{url}\": to many retries, aborting.")
-                    return DownloadResult(
-                        error_message=f"Stream timed out from {url}, reducing the chunksize might help.")
+                    return DownloadResult(error_message=f"Stream timed out from {url}, reducing the chunk_size might help.")

                 self.LOGGER.warning(f"Stream timed out at \"{url}\": ({try_count}-{self.TRIES})")
                 retry = True
@@ -360,19 +316,6 @@ class Connection:
         if retry:
             self.LOGGER.warning(f"Retrying stream...")
             accepted_response_codes.add(206)
-            return self.stream_into(
-                url=url,
-                target=target,
-                description=description,
-                try_count=try_count + 1,
-                progress=progress,
-                accepted_response_codes=accepted_response_codes,
-                timeout=timeout,
-                headers=headers,
-                raw_url=raw_url,
-                refer_from_origin=refer_from_origin,
-                chunk_size=chunk_size,
-                **kwargs
-            )
+            return self.stream_into(**stream_kwargs)

         return DownloadResult()
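The retry branch above resumes rather than restarts: it adds 206 (Partial Content) to the accepted status codes and re-enters stream_into(), whose progress > 0 path sends a Range header starting at the bytes already written. The same technique with plain requests, as a self-contained sketch (helper name and URL are hypothetical):

    import os

    import requests


    def resume_download(url: str, path: str, chunk_size: int = 8192) -> None:
        already = os.path.getsize(path) if os.path.exists(path) else 0
        headers = {"Range": f"bytes={already}-"} if already > 0 else {}

        r = requests.get(url, headers=headers, stream=True, timeout=10)
        if r.status_code not in {200, 206}:  # 206 = Partial Content
            raise RuntimeError(f"unexpected status code {r.status_code}")

        # Append when the server honoured the Range header, otherwise start over.
        with open(path, "ab" if r.status_code == 206 else "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)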