generated from Hazel/python-project
feat: implemented correct lama bindings
This commit is contained in:
parent
eb00e869fc
commit
5baefdcc6f
@ -12,7 +12,7 @@ I first realized that a normal mosaic algorithm isn't safe AT ALL seeing this pr
|
||||
|
||||
```bash
|
||||
# Step 1: Create and activate virtual environment
|
||||
python3 -m venv .venv
|
||||
python3.8 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# Step 2: Install the local Python program (add the -e flag for development)
|
||||
|
@ -2,7 +2,14 @@
|
||||
name = "secure_pixelation"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"torch==2.1.2",
|
||||
"torchvision==0.16.2",
|
||||
|
||||
"opencv_python~=4.11.0.86",
|
||||
"numpy<2.0.0",
|
||||
"hf_transfer==0.1.8",
|
||||
"huggingface_hub==0.25.1",
|
||||
|
||||
"ultralytics~=8.3.114",
|
||||
]
|
||||
authors = []
|
||||
|
@ -10,6 +10,7 @@ import cv2
|
||||
import numpy as np
|
||||
|
||||
from .data_classes import RawImage
|
||||
from .simple_lama_bindings import SimpleLama
|
||||
# https://github.com/okaris/simple-lama/tree/main
|
||||
|
||||
def blackout(raw_image: RawImage) -> np.ndarray:
|
||||
@ -40,39 +41,10 @@ def quick_impaint(raw_image: RawImage, image: Optional[np.ndarray] = None) -> np
|
||||
|
||||
def do_generative_impaint(raw_image: RawImage, image: Optional[np.ndarray] = None) -> np.ndarray:
|
||||
image = image if image is not None else raw_image.get_image()
|
||||
lama_dict = raw_image.get_dir("steps") / "lama"
|
||||
lama_dict.mkdir(exist_ok=True)
|
||||
lama_dict_in = lama_dict / "in"
|
||||
lama_dict_in.mkdir(exist_ok=True)
|
||||
lama_dict_out = lama_dict / "out"
|
||||
lama_dict_out.mkdir(exist_ok=True)
|
||||
|
||||
cv2.imwrite(str(lama_dict_in / "image.png"), raw_image.image)
|
||||
mask = get_mask(raw_image)
|
||||
cv2.imwrite(str(lama_dict_in / "mask.png"), mask)
|
||||
|
||||
# Run LaMa inference (adjust path if needed)
|
||||
try:
|
||||
pwd = os.getcwd()
|
||||
subprocess.run([
|
||||
sys.executable, "lama/bin/predict.py",
|
||||
f"model.path={pwd}/lama/models/big-lama",
|
||||
f"indir={pwd}/{str(lama_dict_in)}",
|
||||
f"outdir={pwd}/{str(lama_dict_out)}"
|
||||
], check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error running LaMa: {e}")
|
||||
print("falling back to non generative inpaint")
|
||||
return quick_impaint(raw_image=raw_image, image=image)
|
||||
|
||||
# Load inpainted result
|
||||
result_path = lama_dict_out / "image.png"
|
||||
if result_path.exists():
|
||||
return cv2.imread(str(result_path))
|
||||
else:
|
||||
print("Inpainted result not found, falling back to non generative inpaint")
|
||||
return quick_impaint(raw_image=raw_image, image=image)
|
||||
|
||||
lama = SimpleLama()
|
||||
return lama(image=image, mask=mask)
|
||||
|
||||
|
||||
def pixelate_regions(raw_image: RawImage, image: Optional[np.ndarray] = None, pixel_size: int = 10) -> np.ndarray:
|
||||
|
76
secure_pixelation/simple_lama_bindings.py
Normal file
76
secure_pixelation/simple_lama_bindings.py
Normal file
@ -0,0 +1,76 @@
|
||||
import os
from typing import Optional, Tuple

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
|
||||
|
||||
# https://github.com/okaris/simple-lama/blob/main/src/simple_lama/simple_lama.py
|
||||
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
def prepare_img_and_mask(
    image: np.ndarray,
    mask: np.ndarray,
    device: torch.device,
    pad_out_to_modulo: int = 8,
    scale_factor: float = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert an image/mask pair into batched tensors for the inpainting model.

    Both arrays are converted to channel-first ``float32`` scaled by ``1/255``:
    a 3-D array is transposed from ``(H, W, C)`` to ``(C, H, W)`` and a 2-D
    array gets a leading channel axis. They are then optionally rescaled,
    padded on the bottom/right edges up to a multiple of
    ``pad_out_to_modulo``, given a leading batch axis and moved to ``device``.

    Args:
        image: Image array, 2-D or 3-D.
        mask: Mask array, 2-D or 3-D; any value ``> 0`` marks a masked pixel.
        device: Torch device the resulting tensors are moved to.
        pad_out_to_modulo: Height and width are padded up to a multiple of
            this value; values ``<= 1`` disable padding.
        scale_factor: Resize factor applied to both arrays before padding;
            ``1`` disables resizing.

    Returns:
        ``(image_tensor, mask_tensor)``. ``image_tensor`` is ``float32``;
        ``mask_tensor`` is an integer tensor holding ``0`` or ``1``.

    Raises:
        ValueError: If image and mask do not share the same height and width.
    """

    def to_chw_float(arr: np.ndarray) -> np.ndarray:
        # ``astype`` already returns a fresh array, so the caller's data is
        # never modified and no explicit copy is required.
        if arr.ndim == 3:
            arr = np.transpose(arr, (2, 0, 1))  # hwc -> chw
        elif arr.ndim == 2:
            arr = arr[np.newaxis, ...]
        return arr.astype(np.float32) / 255

    def scale_image(chw: np.ndarray, factor: float, interpolation: int) -> np.ndarray:
        # cv2.resize works on (H, W) or (H, W, C), so leave channel-first
        # layout for the resize and restore it afterwards.
        hwc = chw[0] if chw.shape[0] == 1 else np.transpose(chw, (1, 2, 0))
        resized = cv2.resize(hwc, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
        return resized[None, ...] if resized.ndim == 2 else np.transpose(resized, (2, 0, 1))

    def pad_img_to_modulo(chw: np.ndarray, mod: int) -> np.ndarray:
        _, height, width = chw.shape
        # ``-n % mod`` is the distance from n up to the next multiple of mod
        # (0 when n is already a multiple).
        pad_h = -height % mod
        pad_w = -width % mod
        return np.pad(chw, ((0, 0), (0, pad_h), (0, pad_w)), mode="symmetric")

    out_image = to_chw_float(image)
    out_mask = to_chw_float(mask)

    # Fail early with a clear message instead of an opaque error from the model.
    if out_image.shape[1:] != out_mask.shape[1:]:
        raise ValueError(
            f"image and mask must have the same height and width, "
            f"got {out_image.shape[1:]} and {out_mask.shape[1:]}"
        )

    if scale_factor != 1:
        out_image = scale_image(out_image, scale_factor, cv2.INTER_AREA)
        # Nearest-neighbour keeps the mask free of interpolated in-between values.
        out_mask = scale_image(out_mask, scale_factor, cv2.INTER_NEAREST)

    if pad_out_to_modulo > 1:
        out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
        out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)

    image_tensor = torch.from_numpy(out_image).unsqueeze(0).to(device)
    mask_tensor = torch.from_numpy(out_mask).unsqueeze(0).to(device)

    # Binarise the mask: every strictly positive value becomes 1.
    return image_tensor, (mask_tensor > 0) * 1
|
||||
|
||||
|
||||
class SimpleLama:
    """Callable wrapper around the TorchScript ``big-lama.pt`` inpainting model.

    Usage::

        lama = SimpleLama()
        result = lama(image, mask)
    """

    def __init__(self, device=None):
        """Download the model checkpoint and load it onto ``device``.

        Args:
            device: Torch device to run on. When falsy, CUDA is used if
                ``torch.cuda.is_available()`` and the CPU otherwise.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Fetches "big-lama.pt" from the "okaris/simple-lama" Hugging Face repo.
        model_path = hf_hub_download("okaris/simple-lama", "big-lama.pt")
        self.model = torch.jit.load(model_path, map_location=self.device).eval()

    def __call__(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Inpaint the masked region of ``image``.

        Args:
            image: Input image array.
            mask: Mask array; values ``> 0`` mark the pixels to inpaint.

        Returns:
            The model output as a ``uint8`` array in height/width/channel
            layout, cropped to the height and width of ``image``.
        """
        height, width = image.shape[:2]
        image_tensor, mask_tensor = prepare_img_and_mask(image, mask, self.device)

        with torch.inference_mode():
            inpainted = self.model(image_tensor, mask_tensor)

        # chw -> hwc, back on the CPU as a NumPy array.
        result = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
        result = np.clip(result * 255, 0, 255).astype(np.uint8)

        # prepare_img_and_mask pads the bottom/right edges up to a multiple of
        # 8; crop that padding away so the result matches the input size.
        return np.ascontiguousarray(result[:height, :width])
|
Loading…
x
Reference in New Issue
Block a user