diff --git a/python_requests/cache.py b/python_requests/cache.py index 85b7f53..b40df91 100644 --- a/python_requests/cache.py +++ b/python_requests/cache.py @@ -1,10 +1,11 @@ +from typing import Optional from codecs import encode from hashlib import sha1 from pathlib import Path import requests import pickle import sqlite3 -from datetime import datetime +from datetime import datetime, timedelta from . import CACHE_DIRECTORY @@ -18,7 +19,7 @@ def _init_db(): conn.execute(""" CREATE TABLE IF NOT EXISTS url_cache ( url_hash TEXT PRIMARY KEY, - last_updated TIMESTAMP + expires_at TIMESTAMP ) """) conn.commit() @@ -36,16 +37,37 @@ def get_url_file(url: str) -> Path: def has_cache(url: str) -> bool: - cache_exists = get_url_file(url).exists() - if not cache_exists: + url_hash = get_url_hash(url) + cache_file = get_url_file(url) + + if not cache_file.exists(): return False - # Check if the URL hash exists in the database - url_hash = get_url_hash(url) + # Check if the cache has expired with sqlite3.connect(DB_FILE) as conn: cursor = conn.cursor() - cursor.execute("SELECT 1 FROM url_cache WHERE url_hash = ?", (url_hash,)) - return cursor.fetchone() is not None + cursor.execute( + "SELECT expires_at FROM url_cache WHERE url_hash = ?", + (url_hash,) + ) + result = cursor.fetchone() + + if result is None: + return False # No expiration record exists + + expires_at = datetime.fromisoformat(result[0]) + if datetime.now() > expires_at: + # Cache expired, clean it up + cache_file.unlink(missing_ok=True) + cursor.execute( + "DELETE FROM url_cache WHERE url_hash = ?", + (url_hash,) + ) + conn.commit() + return False + + return True + def get_cache(url: str) -> requests.Response: @@ -53,9 +75,18 @@ def get_cache(url: str) -> requests.Response: return pickle.load(cache_file) -def write_cache(url: str, resp: requests.Response): +def write_cache( + url: str, + resp: requests.Response, + expires_after: Optional[timedelta] = None +): url_hash = get_url_hash(url) - current_time = datetime.now().isoformat() + + # Default expiration: 24 hours from now + if expires_after is None: + expires_after = timedelta(hours=1) + + expires_at = datetime.now() + expires_after # Write the cache file with get_url_file(url).open("wb") as url_file: @@ -64,7 +95,7 @@ def write_cache(url: str, resp: requests.Response): # Update the database with sqlite3.connect(DB_FILE) as conn: conn.execute( - "INSERT OR REPLACE INTO url_cache (url_hash, last_updated) VALUES (?, ?)", - (url_hash, current_time) + "INSERT OR REPLACE INTO url_cache (url_hash, expires_at) VALUES (?, ?)", + (url_hash, expires_at.isoformat()) ) conn.commit()