diff --git a/.gitignore b/.gitignore
index b38b31c..baefa3a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,4 +161,6 @@ cython_debug/
 #.idea/
 .venv
 assets/*
-*.pt
\ No newline at end of file
+*.pt
+
+big-lama
\ No newline at end of file
diff --git a/README.md b/README.md
index 6772bc0..d825a80 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,16 @@ secure-pixelation
 This is the generative ai model to impaint the blacked out areas.
 
 ```
+# get the pretrained models
 mkdir -p ./big-lama
 wget https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
 unzip big-lama.zip -d ./big-lama
+rm big-lama.zip
+
+# get the code to run the models
+cd big-lama
+git clone https://github.com/advimman/lama.git
+pip install torch==2.2.0 torchvision==0.17.0
+cd lama
+pip install -r requirements.txt
 ```
diff --git a/secure_pixelation/__main__.py b/secure_pixelation/__main__.py
index 0db385f..8ccf921 100644
--- a/secure_pixelation/__main__.py
+++ b/secure_pixelation/__main__.py
@@ -5,4 +5,6 @@ from .pixelation_process import pixelate
 
 def cli():
     print(f"Running secure_pixelation")
-    pixelate("assets/human_detection/test.png")
+    pixelate("assets/human_detection/test.png", generative_impaint=False)
+    pixelate("assets/human_detection/humans.png", generative_impaint=False)
+    pixelate("assets/human_detection/rev1.png", generative_impaint=False)
diff --git a/secure_pixelation/pixelation_process.py b/secure_pixelation/pixelation_process.py
index fa7a37c..7861bb8 100644
--- a/secure_pixelation/pixelation_process.py
+++ b/secure_pixelation/pixelation_process.py
@@ -2,6 +2,9 @@ from __future__ import annotations
 
 from typing import Optional
 from pathlib import Path
+import shutil
+import subprocess
+import sys
 
 import cv2
 import numpy as np
@@ -18,14 +21,18 @@ def blackout(raw_image: RawImage) -> np.ndarray:
 
     return image
 
 
+def get_mask(raw_image: RawImage) -> np.ndarray:
+    mask = np.zeros(raw_image.image.shape[:2], dtype=np.uint8)
+    for (x, y, w, h) in raw_image.bounding_boxes:
+        mask[y:y+h, x:x+w] = 255
+
+    return mask
+
+
 def impaint(raw_image: RawImage, image: Optional[np.ndarray] = None) -> np.ndarray:
     image = image if image is not None else raw_image.get_image()
 
-    # Create a mask where blacked-out areas are marked as 255 (white)
-    mask = np.zeros(image.shape[:2], dtype=np.uint8)
-
-    for (x, y, w, h) in raw_image.bounding_boxes:
-        mask[y:y+h, x:x+w] = 255
+    mask = get_mask(raw_image)
 
     # Apply inpainting using the Telea method
     return cv2.inpaint(image, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
@@ -33,15 +40,41 @@ def impaint(raw_image: RawImage, image: Optional[np.ndarr
 
 def do_generative_impaint(raw_image: RawImage, image: Optional[np.ndarray] = None) -> np.ndarray:
     image = image if image is not None else raw_image.get_image()
+    lama_dict = raw_image.get_dir("steps") / "lama"
+    lama_dict.mkdir(exist_ok=True)
+    lama_dict_in = lama_dict / "in"
+    lama_dict_in.mkdir(exist_ok=True)
+    lama_dict_out = lama_dict / "out"
+    lama_dict_out.mkdir(exist_ok=True)
 
-    # Create a mask where blacked-out areas are marked as 255 (white)
-    mask = np.zeros(image.shape[:2], dtype=np.uint8)
+    cv2.imwrite(str(lama_dict_in / "image.png"), raw_image.image)
+    mask = get_mask(raw_image)
+    cv2.imwrite(str(lama_dict_in / "mask.png"), mask)
 
-    for (x, y, w, h) in raw_image.bounding_boxes:
-        mask[y:y+h, x:x+w] = 255
+    # Run LaMa inference (adjust path if needed)
+    try:
+        subprocess.run([
+            sys.executable, "big-lama/lama/bin/predict.py",
+            "model.path=big-lama/big-lama",
+            f"indir={str(lama_dict_in)}",
+            f"outdir={str(lama_dict_out)}"
+        ], check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"Error running LaMa: {e}")
+        return image  # fallback to original if it fails
 
-    # Apply inpainting using the Telea method
-    return cv2.inpaint(image, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
+    # Load inpainted result
+    result_path = lama_dict_out / "image.png"
+    if result_path.exists():
+        inpainted_image = cv2.imread(str(result_path))
+    else:
+        print("Inpainted result not found, returning original.")
+        inpainted_image = image
+
+    # Cleanup
+    shutil.rmtree(lama_dict)
+
+    return inpainted_image