From a8bcb386184182ec99850bcad045ab8fdb5d8150 Mon Sep 17 00:00:00 2001 From: Hazel Noack Date: Wed, 11 Jun 2025 14:20:24 +0200 Subject: [PATCH] implemented setting cache directory properly --- python_requests/__init__.py | 3 ++- python_requests/__main__.py | 11 ++++++----- python_requests/cache.py | 31 ++++++++++++++----------------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/python_requests/__init__.py b/python_requests/__init__.py index 8f5c13c..d0f841c 100644 --- a/python_requests/__init__.py +++ b/python_requests/__init__.py @@ -1,6 +1,6 @@ import pathlib from .connections import Connection, SilentConnection -from .cache import clean_cache, clear_cache, set_cache_directory +from .cache import clean_cache, clear_cache, set_cache_directory, get_cache_stats __name__ = "python_requests" @@ -10,5 +10,6 @@ __all__ = [ "SilentConnection", "clean_cache", "clear_cache", + "get_cache_stats", "set_cache_directory", ] diff --git a/python_requests/__main__.py b/python_requests/__main__.py index de61258..cd27ed6 100644 --- a/python_requests/__main__.py +++ b/python_requests/__main__.py @@ -1,8 +1,7 @@ import argparse import logging -from .connections import Connection, SilentConnection -from . import cache +from . import Connection, clean_cache, clear_cache, set_cache_directory, get_cache_stats, __folder__ def main(): @@ -49,11 +48,13 @@ def cli(): format='%(asctime)s - %(levelname)s - %(message)s' ) logging.debug("Debug logging enabled") + set_cache_directory("cache") else: logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) + set_cache_directory() if hasattr(args, 'func'): @@ -65,7 +66,7 @@ def cli(): def handle_show_cache(args): """Handle the show-cache command""" try: - file_count, db_count = cache.get_cache_stats() + file_count, db_count = get_cache_stats() logging.info(f"Cache Statistics:") logging.info(f" - Files in cache: {file_count}") logging.info(f" - Database entries: {db_count}") @@ -75,7 +76,7 @@ def handle_show_cache(args): def handle_clean_cache(args): """Handle the clean-cache command""" try: - files_deleted, entries_deleted = cache.clean_cache() + files_deleted, entries_deleted = clean_cache() logging.info(f"Cleaned cache:") logging.info(f" - Files deleted: {files_deleted}") logging.info(f" - Database entries removed: {entries_deleted}") @@ -88,7 +89,7 @@ def handle_clear_cache(args): # Confirm before clearing all cache confirm = input("Are you sure you want to clear ALL cache? This cannot be undone. [y/N]: ") if confirm.lower() == 'y': - files_deleted, entries_deleted = cache.clear_cache() + files_deleted, entries_deleted = clear_cache() logging.info(f"Cleared ALL cache:") logging.info(f" - Files deleted: {files_deleted}") logging.info(f" - Database entries removed: {entries_deleted}") diff --git a/python_requests/cache.py b/python_requests/cache.py index 962f979..b2f312a 100644 --- a/python_requests/cache.py +++ b/python_requests/cache.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from codecs import encode from hashlib import sha1 from pathlib import Path @@ -9,22 +9,10 @@ from datetime import datetime, timedelta from . import __name__ -CACHE_DIRECTORY = f"/tmp/{__name__}" - - -def set_cache_directory(cache_directory: Optional[str] = None): - global CACHE_DIRECTORY - - if cache_directory is not None: - CACHE_DIRECTORY = cache_directory - - Path(CACHE_DIRECTORY).mkdir(exist_ok=True, parents=True) - - -# SQLite database file path +CACHE_DIRECTORY = Path(f"/tmp/{__name__}") DB_FILE = Path(CACHE_DIRECTORY, "cache_metadata.db") -# Initialize the database + def _init_db(): with sqlite3.connect(DB_FILE) as conn: conn.execute(""" @@ -35,8 +23,17 @@ def _init_db(): """) conn.commit() -# Initialize the database when module is imported -_init_db() + +def set_cache_directory(cache_directory: Optional[Union[str, Path]] = None): + global CACHE_DIRECTORY, DB_FILE + + if cache_directory is not None: + CACHE_DIRECTORY = cache_directory + DB_FILE = Path(CACHE_DIRECTORY, "cache_metadata.db") + _init_db() + + print(CACHE_DIRECTORY, DB_FILE) + Path(CACHE_DIRECTORY).mkdir(exist_ok=True, parents=True) def get_url_hash(url: str) -> str: