feat: added type annotation merging for connection function

This commit is contained in:
Hazel 2024-02-28 11:17:07 +01:00
parent 7f6db2781d
commit 7f2abdf572

View File

@ -4,10 +4,12 @@ import time
from typing import List, Dict, Optional, Set
from urllib.parse import urlparse, urlunsplit, ParseResult
import copy
import inspect
import requests
import responses
from tqdm import tqdm
import merge_args
from .cache import Cache
from .rotating import RotatingProxy
@ -125,32 +127,35 @@ class Connection:
def request(
self,
method: str,
try_count: int,
accepted_response_codes: set,
url: str,
timeout: float,
headers: Optional[dict],
timeout: float = None,
headers: Optional[dict] = None,
try_count: int = 0,
accepted_response_codes: set = None,
refer_from_origin: bool = True,
raw_url: bool = False,
sleep_after_404: float = None,
is_heartbeat: bool = False,
disable_cache: bool = None,
method: str = None,
name: str = "",
**kwargs
) -> Optional[requests.Response]:
current_kwargs = copy.copy(locals)
if method is None:
raise AttributeError("method is not set.")
method = method.upper()
disable_cache = headers.get("Cache-Control", "").lower() == "no-cache" if disable_cache is None else disable_cache
accepted_response_codes = self.ACCEPTED_RESPONSE_CODES if accepted_response_codes is None else accepted_response_codes
current_kwargs = copy.copy(locals())
parsed_url = urlparse(url)
headers = self._update_headers(
headers=headers,
refer_from_origin=refer_from_origin,
url=parsed_url
)
disable_cache = headers.get("Cache-Control") == "no-cache" or kwargs.get("disable_cache", False)
if name != "" and not disable_cache:
cached = self.cache.get(name)
@ -225,105 +230,58 @@ class Connection:
current_kwargs["try_count"] = current_kwargs.get("try_count", 0) + 1
return self.request(**current_kwargs)
def get(
self,
url: str,
refer_from_origin: bool = True,
stream: bool = False,
accepted_response_codes: set = None,
timeout: float = None,
headers: dict = None,
raw_url: bool = False,
**kwargs
) -> Optional[requests.Response]:
if accepted_response_codes is None:
accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
r = self.request(
@merge_args(request)
def get(self, *args, **kwargs) -> Optional[requests.Response]:
return self.request(
*args,
method="GET",
try_count=0,
accepted_response_codes=accepted_response_codes,
url=url,
timeout=timeout,
headers=headers,
raw_url=raw_url,
refer_from_origin=refer_from_origin,
stream=stream,
**kwargs
)
if r is None:
self.LOGGER.warning(f"Max attempts ({self.TRIES}) exceeded for: GET:{url}")
return r
@merge_args(request)
def post(
self,
url: str,
*args,
json: dict = None,
refer_from_origin: bool = True,
stream: bool = False,
accepted_response_codes: set = None,
timeout: float = None,
headers: dict = None,
raw_url: bool = False,
**kwargs
) -> Optional[requests.Response]:
r = self.request(
*args,
method="POST",
try_count=0,
accepted_response_codes=accepted_response_codes or self.ACCEPTED_RESPONSE_CODES,
url=url,
timeout=timeout,
headers=headers,
refer_from_origin=refer_from_origin,
raw_url=raw_url,
json=json,
stream=stream,
**kwargs
)
if r is None:
self.LOGGER.warning(f"Max attempts ({self.TRIES}) exceeded for: GET:{url}")
self.LOGGER.warning(f"payload: {json}")
return r
@merge_args(request)
def stream_into(
self,
url: str,
target: Target,
description: str = "download",
refer_from_origin: bool = True,
accepted_response_codes: set = None,
timeout: float = None,
headers: dict = None,
raw_url: bool = False,
name: str = "download",
chunk_size: int = main_settings["chunk_size"],
try_count: int = 0,
progress: int = 0,
method: str = "GET",
**kwargs
) -> DownloadResult:
stream_kwargs = copy.copy(locals())
if progress > 0:
if headers is None:
headers = dict()
headers = dict() if headers is None else headers
headers["Range"] = f"bytes={target.size}-"
if accepted_response_codes is None:
accepted_response_codes = self.ACCEPTED_RESPONSE_CODES
r = self.request(
method="GET",
try_count=0,
accepted_response_codes=accepted_response_codes,
url=url,
timeout=timeout,
headers=headers,
raw_url=raw_url,
refer_from_origin=refer_from_origin,
stream=True,
name=name,
chunk_size=chunk_size,
method=method,
**kwargs
)
if r is None:
return DownloadResult(error_message=f"Could not establish connection to: {url}")
return DownloadResult(error_message=f"Could not establish a stream from: {url}")
target.create_path()
total_size = int(r.headers.get('content-length'))
@ -337,8 +295,7 @@ class Connection:
> The internationally recommended unit symbol for the kilobyte is kB.
"""
with tqdm(total=total_size - target.size, unit='B', unit_scale=True, unit_divisor=1024,
desc=description) as t:
with tqdm(total=total_size - target.size, unit='B', unit_scale=True, unit_divisor=1024, desc=name) as t:
try:
for chunk in r.iter_content(chunk_size=chunk_size):
size = f.write(chunk)
@ -348,8 +305,7 @@ class Connection:
except requests.exceptions.ConnectionError:
if try_count >= self.TRIES:
self.LOGGER.warning(f"Stream timed out at \"{url}\": to many retries, aborting.")
return DownloadResult(
error_message=f"Stream timed out from {url}, reducing the chunksize might help.")
return DownloadResult(error_message=f"Stream timed out from {url}, reducing the chunk_size might help.")
self.LOGGER.warning(f"Stream timed out at \"{url}\": ({try_count}-{self.TRIES})")
retry = True
@ -360,19 +316,6 @@ class Connection:
if retry:
self.LOGGER.warning(f"Retrying stream...")
accepted_response_codes.add(206)
return self.stream_into(
url=url,
target=target,
description=description,
try_count=try_count + 1,
progress=progress,
accepted_response_codes=accepted_response_codes,
timeout=timeout,
headers=headers,
raw_url=raw_url,
refer_from_origin=refer_from_origin,
chunk_size=chunk_size,
**kwargs
)
return self.stream_into(**stream_kwargs)
return DownloadResult()