From 5baefdcc6ff85f3c9749043a07e2239476029e1e Mon Sep 17 00:00:00 2001
From: Lars Noack
Date: Thu, 24 Apr 2025 16:50:37 +0200
Subject: [PATCH] feat: implemented correct lama bindings

---
 README.md                                 |   2 +-
 pyproject.toml                            |   7 ++
 secure_pixelation/pixelation_process.py   |  34 +-----
 secure_pixelation/simple_lama_bindings.py | 119 ++++++++++++++++++++++
 4 files changed, 130 insertions(+), 32 deletions(-)
 create mode 100644 secure_pixelation/simple_lama_bindings.py

diff --git a/README.md b/README.md
index 582404e..a161ebd 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@ I first realized that a normal mosaic algorithm isn't safe AT ALL seeing this pr
 
 ```bash
 # Step 1: Create and activate virtual environment
-python3 -m venv .venv
+python3.8 -m venv .venv
 source .venv/bin/activate
 
 # Step 2: Install the local Python program add the -e flag for development
diff --git a/pyproject.toml b/pyproject.toml
index e2b1ea3..a7f49b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,14 @@
 name = "secure_pixelation"
 version = "0.0.0"
 dependencies = [
 
+    "torch==2.1.2",
+    "torchvision==0.16.2",
+    "opencv_python~=4.11.0.86",
+    "numpy<2.0.0",
+    "hf_transfer==0.1.8",
+    "huggingface_hub==0.25.1",
+    "ultralytics~=8.3.114",
 ]
 authors = []
 
diff --git a/secure_pixelation/pixelation_process.py b/secure_pixelation/pixelation_process.py
index 3790569..8f38394 100644
--- a/secure_pixelation/pixelation_process.py
+++ b/secure_pixelation/pixelation_process.py
@@ -10,6 +10,7 @@ import cv2
 import numpy as np
 
 from .data_classes import RawImage
+from .simple_lama_bindings import get_default_lama  # https://github.com/okaris/simple-lama/tree/main
 
 
 def blackout(raw_image: RawImage) -> np.ndarray:
@@ -40,39 +41,10 @@ def quick_impaint(raw_image: RawImage, image: Optional[np.ndarray] = None) -> np
 
 def do_generative_impaint(raw_image: RawImage, image: Optional[np.ndarray] = None) -> np.ndarray:
     image = image if image is not None else raw_image.get_image()
-    lama_dict = raw_image.get_dir("steps") / "lama"
-    lama_dict.mkdir(exist_ok=True)
-    lama_dict_in = lama_dict / "in"
-    lama_dict_in.mkdir(exist_ok=True)
-    lama_dict_out = lama_dict / "out"
-    lama_dict_out.mkdir(exist_ok=True)
-
-    cv2.imwrite(str(lama_dict_in / "image.png"), raw_image.image)
     mask = get_mask(raw_image)
-    cv2.imwrite(str(lama_dict_in / "mask.png"), mask)
-
-    # Run LaMa inference (adjust path if needed)
-    try:
-        pwd = os.getcwd()
-        subprocess.run([
-            sys.executable, "lama/bin/predict.py",
-            f"model.path={pwd}/lama/models/big-lama",
-            f"indir={pwd}/{str(lama_dict_in)}",
-            f"outdir={pwd}/{str(lama_dict_out)}"
-        ], check=True)
-    except subprocess.CalledProcessError as e:
-        print(f"Error running LaMa: {e}")
-        print("falling back to non generative inpaint")
-        return quick_impaint(raw_image=raw_image, image=image)
-
-    # Load inpainted result
-    result_path = lama_dict_out / "image.png"
-    if result_path.exists():
-        return cv2.imread(str(result_path))
-    else:
-        print("Inpainted result not found, falling back to non generative inpaint")
-        return quick_impaint(raw_image=raw_image, image=image)
 
+    lama = get_default_lama()  # shared instance: the model is loaded only once
+    return lama(image=image, mask=mask)
 
 
 def pixelate_regions(raw_image: RawImage, image: Optional[np.ndarray] = None, pixel_size: int = 10) -> np.ndarray:
diff --git a/secure_pixelation/simple_lama_bindings.py b/secure_pixelation/simple_lama_bindings.py
new file mode 100644
index 0000000..654aa12
--- /dev/null
+++ b/secure_pixelation/simple_lama_bindings.py
@@ -0,0 +1,119 @@
+"""LaMa inpainting bindings backed by a TorchScript export of big-lama.
+
+Adapted from
+https://github.com/okaris/simple-lama/blob/main/src/simple_lama/simple_lama.py
+"""
+
+import functools
+import os
+from typing import Optional, Tuple
+
+# huggingface_hub evaluates HF_HUB_ENABLE_HF_TRANSFER when it is imported, so
+# the flag must be set before the import below for it to have any effect.
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+import torch  # noqa: E402
+import cv2  # noqa: E402
+import numpy as np  # noqa: E402
+from huggingface_hub import hf_hub_download  # noqa: E402
+
+
+def prepare_img_and_mask(
+    image: np.ndarray,
+    mask: np.ndarray,
+    device: torch.device,
+    pad_out_to_modulo: int = 8,
+    scale_factor: float = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Turn an image/mask pair into batched tensors on ``device``.
+
+    Both arrays are converted to channel-first float32 (values divided by
+    255), optionally resized by ``scale_factor`` and then mirror-padded at the
+    bottom/right so height and width are multiples of ``pad_out_to_modulo``.
+
+    Returns ``(image, mask)``, each with a leading batch axis; every mask
+    value greater than zero becomes 1, everything else 0.
+    """
+
+    def to_chw_float(arr: np.ndarray) -> np.ndarray:
+        # HxWxC -> CxHxW, HxW -> 1xHxW; astype() copies, the input is untouched
+        if arr.ndim == 3:
+            arr = np.transpose(arr, (2, 0, 1))
+        elif arr.ndim == 2:
+            arr = arr[np.newaxis, ...]
+        return arr.astype(np.float32) / 255
+
+    def scale_image(arr: np.ndarray, factor: float, interpolation=cv2.INTER_AREA) -> np.ndarray:
+        # cv2.resize wants HxW or HxWxC, so leave channel-first for the call
+        hwc = arr[0] if arr.shape[0] == 1 else np.transpose(arr, (1, 2, 0))
+        hwc = cv2.resize(hwc, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
+        return hwc[None, ...] if hwc.ndim == 2 else np.transpose(hwc, (2, 0, 1))
+
+    def pad_to_modulo(arr: np.ndarray, mod: int) -> np.ndarray:
+        _, height, width = arr.shape
+        pad_h = -height % mod  # 0 when height is already a multiple of mod
+        pad_w = -width % mod
+        return np.pad(arr, ((0, 0), (0, pad_h), (0, pad_w)), mode="symmetric")
+
+    out_image = to_chw_float(image)
+    out_mask = to_chw_float(mask)
+    if out_mask.shape[0] > 1:
+        # A multi-channel mask is collapsed to one channel: a pixel counts as
+        # masked as soon as any of its channels is set.
+        out_mask = out_mask.max(axis=0, keepdims=True)
+
+    if scale_factor != 1:
+        out_image = scale_image(out_image, scale_factor)
+        # nearest-neighbour keeps the mask free of blended in-between values
+        out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)
+
+    if pad_out_to_modulo > 1:
+        out_image = pad_to_modulo(out_image, pad_out_to_modulo)
+        out_mask = pad_to_modulo(out_mask, pad_out_to_modulo)
+
+    image_batch = torch.from_numpy(out_image).unsqueeze(0).to(device)
+    mask_batch = torch.from_numpy(out_mask).unsqueeze(0).to(device)
+
+    return image_batch, (mask_batch > 0) * 1
+
+
+class SimpleLama:
+    """Callable wrapper around the TorchScript big-lama model.
+
+    Usage::
+
+        lama = SimpleLama()
+        result = lama(image, mask)
+    """
+
+    def __init__(self, device: Optional[torch.device] = None):
+        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Using device: {self.device}")
+
+        # Resolved through huggingface_hub, which downloads the file if needed
+        model_path = hf_hub_download("okaris/simple-lama", "big-lama.pt")
+        self.model = torch.jit.load(model_path, map_location=self.device).eval()
+
+    def __call__(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
+        """Inpaint the masked region of ``image`` and return a uint8 HxWxC array.
+
+        The result has the same height and width as ``image``.
+        """
+        # NOTE(review): OpenCV arrays are usually BGR while LaMa was presumably
+        # trained on RGB input - confirm whether the channels should be swapped.
+        height, width = image.shape[:2]
+        batch_image, batch_mask = prepare_img_and_mask(image, mask, self.device)
+
+        with torch.inference_mode():
+            inpainted = self.model(batch_image, batch_mask)
+            result = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
+
+        result = np.clip(result * 255, 0, 255).astype(np.uint8)
+        # The model ran on an input padded to a multiple of 8; drop that padding
+        # so the caller gets the original dimensions back.
+        return np.ascontiguousarray(result[:height, :width])
+
+
+@functools.lru_cache(maxsize=1)
+def get_default_lama() -> SimpleLama:
+    """Return a shared SimpleLama instance, loading the model only once."""
+    return SimpleLama()