secure_pixelation/secure_pixelation/simple_lama_bindings.py

77 lines
2.8 KiB
Python

import os
from typing import Optional
import torch
import cv2
import numpy as np
from huggingface_hub import hf_hub_download
# https://github.com/okaris/simple-lama/blob/main/src/simple_lama/simple_lama.py
# Ask huggingface_hub to use the optional `hf_transfer` download backend.
# This unconditionally overwrites any value already present in the
# environment, and it also leaks into any child processes started later.
# NOTE(review): huggingface_hub is imported above, before this line runs.
# Versions that read HF_HUB_ENABLE_HF_TRANSFER only once at import time will
# ignore this assignment for hf_hub_download. If the speed-up is actually
# wanted, set the variable before importing huggingface_hub and make sure the
# `hf_transfer` package is installed. Confirm against the pinned version.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def prepare_img_and_mask(image: np.ndarray, mask: np.ndarray, device: torch.device, pad_out_to_modulo: int = 8, scale_factor: float = 1) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert an image/mask pair into batched tensors for the LaMa model.

    Both arrays go through the same steps:
    - convert to channel-first float32 divided by 255 (8-bit input assumed);
    - optionally resize by ``scale_factor``;
    - pad symmetrically so height and width are multiples of
      ``pad_out_to_modulo``;
    - add a leading batch dimension and move to ``device``.

    The caller's arrays are never modified.

    Args:
        image: ``(H, W, C)`` or ``(H, W)`` array.
        mask: ``(H, W)`` or ``(H, W, C)`` array; pixels greater than zero mark
            the region to inpaint. A multi-channel mask is collapsed to a
            single channel (a pixel is masked if any channel is > 0).
        device: Device the returned tensors are moved to.
        pad_out_to_modulo: Pad spatial dims up to a multiple of this value;
            values ``<= 1`` disable padding.
        scale_factor: Resize factor applied before padding; ``1`` disables
            resizing.

    Returns:
        ``(image, mask)``:
        - ``image``: float32 tensor of shape ``(1, C, H', W')``.
        - ``mask``: int64 tensor of shape ``(1, 1, H', W')`` holding only
          0 and 1.
    """

    def get_image(img: np.ndarray) -> np.ndarray:
        # HWC -> CHW, or HW -> 1HW. astype() always returns a new array,
        # so the input is never mutated.
        if img.ndim == 3:
            img = np.transpose(img, (2, 0, 1))
        elif img.ndim == 2:
            img = img[np.newaxis, ...]
        return img.astype(np.float32) / 255

    def scale_image(img: np.ndarray, factor: float, interpolation: int) -> np.ndarray:
        # cv2.resize operates on HW / HWC arrays, so leave CHW and come back.
        hwc = img[0] if img.shape[0] == 1 else np.transpose(img, (1, 2, 0))
        resized = cv2.resize(hwc, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
        return resized[None, ...] if resized.ndim == 2 else np.transpose(resized, (2, 0, 1))

    def pad_img_to_modulo(img: np.ndarray, mod: int) -> np.ndarray:
        _, height, width = img.shape
        # Round each spatial dimension up to the next multiple of `mod`.
        out_height = -(-height // mod) * mod
        out_width = -(-width // mod) * mod
        return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode="symmetric")

    out_image = get_image(image)
    out_mask = get_image(mask)
    if out_mask.ndim == 3 and out_mask.shape[0] > 1:
        # The mask is fed to the model as a single channel. A 3-channel mask
        # (e.g. from cv2.imread without IMREAD_GRAYSCALE) would otherwise
        # reach the model with the wrong channel count. Taking the max keeps
        # the "> 0 means masked" semantics applied below.
        out_mask = out_mask.max(axis=0, keepdims=True)
    if scale_factor != 1:
        out_image = scale_image(out_image, scale_factor, cv2.INTER_AREA)
        # Nearest-neighbour keeps the mask binary instead of blending edges.
        out_mask = scale_image(out_mask, scale_factor, cv2.INTER_NEAREST)
    if pad_out_to_modulo > 1:
        out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
        out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)
    image_tensor = torch.from_numpy(out_image).unsqueeze(0).to(device)
    mask_tensor = torch.from_numpy(out_mask).unsqueeze(0).to(device)
    return image_tensor, (mask_tensor > 0) * 1
class SimpleLama:
    """Thin wrapper around the TorchScript ``big-lama`` inpainting model.

    Example::

        lama = SimpleLama()
        result = lama(image, mask)

    The weights (``big-lama.pt``) are fetched from the ``okaris/simple-lama``
    Hugging Face Hub repository when the object is constructed.
    """

    def __init__(self, device=None):
        """Load the TorchScript model onto ``device``.

        Args:
            device: Torch device to run on. Defaults to CUDA when available,
                otherwise CPU.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        model_path = hf_hub_download("okaris/simple-lama", "big-lama.pt")
        self.model = torch.jit.load(model_path, map_location=self.device).eval()

    def __call__(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Inpaint the masked region of ``image``.

        Args:
            image: ``(H, W, C)`` 8-bit image. Presumably 3-channel in the
                channel order the model was trained on (RGB?); confirm
                against callers.
            mask: ``(H, W)`` array; pixels greater than zero are inpainted.

        Returns:
            ``(H, W, C_out)`` uint8 array with the same height and width as
            ``image``.
        """
        height, width = image.shape[:2]
        image_tensor, mask_tensor = prepare_img_and_mask(image, mask, self.device)
        with torch.inference_mode():
            inpainted = self.model(image_tensor, mask_tensor)
        # prepare_img_and_mask pads H and W up to a multiple of 8. Crop that
        # padding off so the result lines up pixel-for-pixel with the input.
        result = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()[:height, :width]
        return np.clip(result * 255, 0, 255).astype(np.uint8)