from __future__ import annotations

import copy
import logging
import threading
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from urllib.parse import ParseResult, urlparse, urlunsplit

import requests
import responses
from tqdm import tqdm

from .cache import Cache
from .rotating import RotatingProxy

if TYPE_CHECKING:
    from ..objects import Target

from ..utils import request_trace
from ..utils.config import main_settings
from ..utils.hacking import merge_args
from ..utils.string_processing import shorten_display_url
from ..utils.support_classes.download_result import DownloadResult


class Connection:
    def __init__(
            self,
            host: str = None,
            proxies: List[dict] = None,
            tries: int = (len(main_settings["proxies"]) + 1) * main_settings["tries_per_proxy"],
            timeout: int = 7,
            logger: logging.Logger = logging.getLogger("connection"),
            header_values: Dict[str, str] = None,
            accepted_response_codes: Set[int] = None,
            semantic_not_found: bool = True,
            sleep_after_404: float = 0.0,
            heartbeat_interval: float = 0,
            module: str = "general",
            cache_expiring_duration: float = 10
    ):
        if proxies is None:
            proxies = main_settings["proxies"]

        self.cache: Cache = Cache(module=module, logger=logger)
        self.cache_expiring_duration = cache_expiring_duration

        self.HEADER_VALUES = dict() if header_values is None else header_values

        self.LOGGER = logger
        self.HOST = host if host is None else urlparse(host)
        self.TRIES = tries
        self.TIMEOUT = timeout
        self.rotating_proxy = RotatingProxy(proxy_list=proxies)

        self.ACCEPTED_RESPONSE_CODES = accepted_response_codes or {200}
        self.SEMANTIC_NOT_FOUND = semantic_not_found
        self.sleep_after_404 = sleep_after_404

        self.session = requests.Session()
        self.session.headers = self.get_header(**self.HEADER_VALUES)
        self.session.proxies = self.rotating_proxy.current_proxy

        self.heartbeat_thread = None
        self.heartbeat_interval = heartbeat_interval
        self.lock: bool = False

    def start_heartbeat(self):
        if self.heartbeat_interval <= 0:
            self.LOGGER.warning(f"Can't start a heartbeat with an interval of {self.heartbeat_interval}s.")
            return

        self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, args=(self.heartbeat_interval,), daemon=True)
        self.heartbeat_thread.start()

    def heartbeat_failed(self):
        self.LOGGER.warning("The heart couldn't beat.")

    def heartbeat(self):
        # Subclasses implement the actual heartbeat request here.
        raise NotImplementedError("Please implement the heartbeat function.")

    def _heartbeat_loop(self, interval: float):
        def heartbeat_wrapper():
            self.LOGGER.debug("The heart is beating.")
            self.heartbeat()

        while True:
            heartbeat_wrapper()
            time.sleep(interval)
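    # A minimal sketch of how the heartbeat hooks are meant to be wired up
    # (the subclass name and ping URL are illustrative assumptions, not part
    # of this module):
    #
    #     class PingingConnection(Connection):
    #         def heartbeat(self):
    #             r = self.get("https://example.com/ping", is_heartbeat=True)
    #             if r is None:
    #                 self.heartbeat_failed()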
    def base_url(self, url: ParseResult = None):
        if url is None and self.HOST is not None:
            url = self.HOST

        return urlunsplit((url.scheme, url.netloc, "", "", ""))

    def get_header(self, **header_values) -> Dict[str, str]:
        headers = {
            "user-agent": main_settings["user_agent"],
            "User-Agent": main_settings["user_agent"],
            "Connection": "keep-alive",
            "Accept-Language": main_settings["language"],
        }

        if self.HOST is not None:
            # headers["Host"] = self.HOST.netloc
            headers["Referer"] = self.base_url(url=self.HOST)

        headers.update(header_values)

        return headers

    def rotate(self):
        self.session.proxies = self.rotating_proxy.rotate()

    def _update_headers(
            self,
            headers: Optional[dict],
            refer_from_origin: bool,
            url: ParseResult
    ) -> Dict[str, str]:
        headers = self.get_header(**(headers or {}))
        if not refer_from_origin:
            headers["Referer"] = self.base_url(url=url)

        return headers

    def save(self, r: requests.Response, name: str, error: bool = False, no_update_if_valid_exists: bool = False, **kwargs):
        n_kwargs = {}
        if error:
            n_kwargs["module"] = "failed_requests"

        if self.cache.get(name) is not None and no_update_if_valid_exists:
            return

        self.cache.set(r.content, name, expires_in=kwargs.get("expires_in", self.cache_expiring_duration), additional_info={
            "encoding": r.encoding,
        }, **n_kwargs)

    def request(
            self,
            url: str,
            timeout: float = None,
            headers: Optional[dict] = None,
            try_count: int = 0,
            accepted_response_codes: set = None,
            refer_from_origin: bool = True,
            raw_url: bool = False,
            raw_headers: bool = False,
            sleep_after_404: float = None,
            is_heartbeat: bool = False,
            disable_cache: bool = None,
            enable_cache_readonly: bool = False,
            method: str = None,
            name: str = "",
            exclude_headers: List[str] = None,
            **kwargs
    ) -> Optional[requests.Response]:
        if method is None:
            raise AttributeError("method is not set.")
        method = method.upper()
        headers = dict() if headers is None else headers
        disable_cache = (headers.get("Cache-Control", "").lower() == "no-cache" if disable_cache is None else disable_cache) or kwargs.get("stream", False)
        accepted_response_codes = self.ACCEPTED_RESPONSE_CODES if accepted_response_codes is None else accepted_response_codes

        current_kwargs = copy.copy(locals())
        current_kwargs.pop("kwargs")
        current_kwargs.update(**kwargs)

        parsed_url = urlparse(url)
        trace_string = f"{method} {shorten_display_url(url)} \t{'[stream]' if kwargs.get('stream', False) else ''}"

        if not raw_headers:
            _headers = copy.copy(self.HEADER_VALUES)
            _headers.update(headers)

            headers = self._update_headers(
                headers=_headers,
                refer_from_origin=refer_from_origin,
                url=parsed_url
            )
        else:
            headers = headers or {}

        request_url = parsed_url.geturl() if not raw_url else url

        if name != "" and (not disable_cache or enable_cache_readonly):
            cached = self.cache.get(name)

            if cached is not None:
                request_trace(f"{trace_string}\t[cached]")

                # Replay the cached body through a mocked transport, so the
                # caller still receives a real requests.Response object.
                with responses.RequestsMock() as resp:
                    additional_info = cached.attribute.additional_info

                    body = cached.content
                    if additional_info.get("encoding", None) is not None:
                        body = body.decode(additional_info["encoding"])

                    resp.add(
                        method=method,
                        url=request_url,
                        body=body,
                    )

                    return requests.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs)

        if sleep_after_404 is None:
            sleep_after_404 = self.sleep_after_404
        if try_count >= self.TRIES:
            return

        if timeout is None:
            timeout = self.TIMEOUT

        for header in exclude_headers or []:
            if header in headers:
                del headers[header]

        if try_count <= 0:
            request_trace(trace_string)

        r = None
        try:
            if self.lock:
                self.LOGGER.info("Waiting for the heartbeat to finish.")
                # spin until the heartbeat releases the lock
                while self.lock and not is_heartbeat:
                    pass

            self.lock = True
            r: requests.Response = self.session.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs)

            if r.status_code in accepted_response_codes:
                if not disable_cache:
                    self.save(r, name, **kwargs)
                return r

        # the server rejected the request, or the internet connection dropped
        except requests.exceptions.Timeout:
            self.LOGGER.warning(f"Request timed out at \"{request_url}\": ({try_count}-{self.TRIES})")
        except requests.exceptions.ConnectionError:
            self.LOGGER.warning(f"Couldn't connect to \"{request_url}\": ({try_count}-{self.TRIES})")

        # this is important for thread safety
        finally:
            self.lock = False

        if r is None:
            self.LOGGER.warning(f"{parsed_url.netloc} didn't respond at {url}. ({try_count}-{self.TRIES})")
            self.LOGGER.debug("request headers:\n\t" + "\n\t".join(f"{k}\t=\t{v}" for k, v in headers.items()))
        else:
            self.LOGGER.warning(f"{parsed_url.netloc} responded with {r.status_code} at {url}. ({try_count}-{self.TRIES})")
            self.LOGGER.debug("request headers:\n\t" + "\n\t".join(f"{k}\t=\t{v}" for k, v in r.request.headers.items()))
            self.LOGGER.debug("response headers:\n\t" + "\n\t".join(f"{k}\t=\t{v}" for k, v in r.headers.items()))
            self.LOGGER.debug(r.content)

            if name != "":
                self.save(r, name, error=True, **kwargs)

            if self.SEMANTIC_NOT_FOUND and r.status_code == 404:
                return None

        if sleep_after_404 != 0:
            self.LOGGER.warning(f"Waiting for {sleep_after_404} seconds.")
            time.sleep(sleep_after_404)

        self.rotate()

        current_kwargs["try_count"] = current_kwargs.get("try_count", 0) + 1
        return Connection.request(**current_kwargs)

    @merge_args(request)
    def get(self, *args, **kwargs) -> Optional[requests.Response]:
        return self.request(
            *args,
            method="GET",
            **kwargs
        )

    @merge_args(request)
    def post(
            self,
            *args,
            json: dict = None,
            **kwargs
    ) -> Optional[requests.Response]:
        r = self.request(
            *args,
            method="POST",
            json=json,
            **kwargs
        )
        if r is None:
            self.LOGGER.warning(f"payload: {json}")
        return r
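    # A minimal usage sketch for get/post; the host, endpoint and cache name
    # are illustrative assumptions, not part of this module:
    #
    #     connection = Connection(host="https://api.example.com", module="example")
    #     r = connection.get("https://api.example.com/artist/123", name="artist_123")
    #     if r is not None:
    #         data = r.json()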
    @merge_args(request)
    def stream_into(
            self,
            url: str,
            target: Target,
            name: str = "download",
            chunk_size: int = main_settings["chunk_size"],
            progress: int = 0,
            method: str = "GET",
            try_count: int = 0,
            accepted_response_codes: set = None,
            **kwargs
    ) -> DownloadResult:
        accepted_response_codes = self.ACCEPTED_RESPONSE_CODES if accepted_response_codes is None else accepted_response_codes

        stream_kwargs = copy.copy(locals())
        stream_kwargs.update(stream_kwargs.pop("kwargs"))

        if "description" in kwargs:
            name = kwargs.pop("description")

        if progress > 0:
            headers = kwargs.get("headers", dict())
            headers["Range"] = f"bytes={target.size}-"
            # write the dict back, in case "headers" wasn't in kwargs yet
            kwargs["headers"] = headers

        r = self.request(
            url=url,
            name=name,
            method=method,
            stream=True,
            accepted_response_codes=accepted_response_codes,
            **kwargs
        )

        if r is None:
            return DownloadResult(error_message=f"Could not establish a stream from: {url}")

        target.create_path()
        total_size = int(r.headers.get("content-length", r.headers.get("Content-Length", chunk_size)))
        progress = 0

        retry = False

        with target.open("ab") as f:
            """
            https://en.wikipedia.org/wiki/Kilobyte
            > The internationally recommended unit symbol for the kilobyte is kB.
            """
            with tqdm(total=total_size, initial=target.size, unit='B', unit_scale=True, unit_divisor=1024, desc=name) as t:
                try:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        size = f.write(chunk)
                        progress += size
                        t.update(size)
                except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.ChunkedEncodingError):
                    if try_count >= self.TRIES:
                        self.LOGGER.warning(f"Stream timed out at \"{url}\": too many retries, aborting.")
                        return DownloadResult(error_message=f"Stream timed out from {url}, reducing the chunk_size might help.")

                    self.LOGGER.warning(f"Stream timed out at \"{url}\": ({try_count}-{self.TRIES})")
                    retry = True
                    try_count += 1

        if total_size > progress:
            retry = True

        if retry:
            self.LOGGER.warning("Retrying stream...")
            # accept 206 Partial Content for the resumed range request
            accepted_response_codes.add(206)
            stream_kwargs["progress"] = progress
            stream_kwargs["try_count"] = try_count
            return Connection.stream_into(**stream_kwargs)

        return DownloadResult()
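# A minimal download sketch; the Target constructor shown here is an
# assumption about ..objects.Target, not its documented signature:
#
#     target = Target(file_path=Path("song.mp3"))
#     result = connection.stream_into("https://example.com/song.mp3", target, description="song")
#     if result.error_message is not None:
#         print(result.error_message)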