import json
import logging
import re
from functools import wraps
from typing import Optional, List, Dict, Union, Any, Callable
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode

import requests

from .exceptions import (
    SponsorBlockError,
    SponsorBlockIdNotFoundError,
    SponsorBlockConnectionError,
)
from .constants import Segment

# A bare video id: exactly 11 characters from the URL-safe alphabet.
_VIDEO_ID_RE = re.compile(r"^[a-zA-Z0-9_-]{11}$")


def error_handling(default: Any) -> Callable:
    """Build a method decorator that can swallow ``SponsorBlockError``.

    The decorated method must belong to an object exposing ``silent``,
    ``_requests_logging_exists`` and ``logger`` attributes.

    Args:
        default: Value returned in place of the method's result when a
            ``SponsorBlockError`` is suppressed. The very same object is
            returned on every suppressed call, so callers should not mutate it.

    Returns:
        A decorator to apply to instance methods.
    """

    def _decorator(func: Callable) -> Callable:
        @wraps(func)
        def _wrapper(self, *args, **kwargs) -> Any:
            try:
                return func(self, *args, **kwargs)
            except SponsorBlockError as e:
                if not self.silent:
                    # Bare ``raise`` keeps the original traceback intact.
                    raise

                # Connection errors are not logged here when the flag is set;
                # presumably another component already reports them -- confirm.
                if self._requests_logging_exists and isinstance(e, SponsorBlockConnectionError):
                    return default

                self.logger.error(repr(e))
                return default

        return _wrapper

    return _decorator


class SponsorBlock:
    """Small client for the SponsorBlock HTTP API."""

    def __init__(
        self,
        session: requests.Session = None,
        base_url: str = "https://sponsor.ajay.app",
        silent: bool = False,
        _requests_logging_exists: bool = False,
    ):
        """Create a client.

        Args:
            session: ``requests`` session used for every call; a new one is
                created when omitted.
            base_url: Root URL that endpoints are appended to.
            silent: When true, methods wrapped with ``error_handling`` return
                their default instead of raising ``SponsorBlockError``.
            _requests_logging_exists: When true (and ``silent`` is set),
                connection errors are suppressed without being logged.
        """
        self.base_url: str = base_url
        self.session: requests.Session = session or requests.Session()
        self.silent: bool = silent
        self._requests_logging_exists: bool = _requests_logging_exists
        # getLogger() registers the logger in the logging hierarchy so that
        # application-level configuration applies to it; instantiating
        # logging.Logger directly bypasses that.
        self.logger: logging.Logger = logging.getLogger("SponsorBlock")

    def _get_video_id(self, video: str) -> str:
        """Extract a video id from a bare id or a URL.

        Args:
            video: An 11-character id, a ``youtu.be/<id>`` URL, or a URL that
                carries the id in its ``v`` query parameter. Surrounding
                whitespace is ignored.

        Returns:
            The video id.

        Raises:
            SponsorBlockIdNotFoundError: If no id can be found in ``video``.
        """
        video = video.strip()
        if _VIDEO_ID_RE.match(video):
            return video

        url = urlparse(url=video)

        if url.netloc == "youtu.be":
            # Short links carry the id as the path: "/<id>".
            video_id = url.path[1:]
            if not video_id:
                raise SponsorBlockIdNotFoundError("No video id found in the url")
            return video_id

        query_stuff = parse_qs(url.query)
        if "v" not in query_stuff:
            raise SponsorBlockIdNotFoundError("No video id found in the url")
        return query_stuff["v"][0]

    def _request(self, method: str, endpoint: str) -> Union[List, Dict]:
        """Send a request to ``base_url + endpoint`` and decode the JSON body.

        Args:
            method: HTTP method name, passed to ``session.request``.
            endpoint: Path (plus optional query string) appended to ``base_url``.

        Returns:
            The decoded JSON payload.

        Raises:
            SponsorBlockConnectionError: On timeout, on connection failure, or
                when the response body is not valid JSON.
        """
        url = self.base_url + endpoint

        try:
            r = self.session.request(method=method, url=url)
        except requests.exceptions.Timeout as err:
            raise SponsorBlockConnectionError(f"Request timed out at \"{url}\"") from err
        except requests.exceptions.ConnectionError as err:
            raise SponsorBlockConnectionError(f"Couldn't connect to \"{url}\"") from err

        if r.status_code == 400:
            # Only a warning: the body is still parsed below.
            self.logger.warning("%s returned 400, meaning I did something wrong.", url)

        try:
            return r.json()
        except ValueError as err:
            # ValueError is the common base of json.JSONDecodeError and the
            # simplejson variant that requests may raise instead.
            raise SponsorBlockConnectionError(f"{r.content} is invalid json.") from err

    @error_handling(default=[])
    def get_segments(self, video: str) -> List[Segment]:
        """Fetch the skip segments for a video.

        Args:
            video: Video id or URL, resolved via ``_get_video_id``.

        Returns:
            One ``Segment`` per item of the ``/api/skipSegments`` response,
            each built by passing the item's keys as keyword arguments. When
            ``silent`` is set and a ``SponsorBlockError`` occurs, the
            decorator's default (an empty list) is returned instead.
        """
        video_id = self._get_video_id(video)

        # build query parameters
        query = {
            "videoID": video_id
        }

        r = self._request(method="GET", endpoint="/api/skipSegments?" + urlencode(query))

        return [Segment(**d) for d in r]