generated from Hazel/python-project
Compare commits
30 Commits
ff2088c1d0
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54a2138746 | ||
|
|
37a5da37b0 | ||
|
|
edd8096030 | ||
|
|
d576f9979c | ||
|
|
f6a774a01f | ||
|
|
6126e675f1 | ||
|
|
df4b949dd2 | ||
|
|
6101a8d5e4 | ||
|
|
b4c7512a73 | ||
|
|
ed650dcc5d | ||
|
|
aaa706264d | ||
|
|
8d6eecaf78 | ||
|
|
f23fd1cdb3 | ||
|
|
2467b4788f | ||
| 8bd512a0a7 | |||
| 8fc56b887d | |||
| edad12841f | |||
| 5baefdcc6f | |||
| eb00e869fc | |||
| 94b641cbd6 | |||
| 061cc20046 | |||
| 8753e1e05f | |||
| 529e1af517 | |||
| ad38eef03b | |||
| 678aeab7a5 | |||
| 180b41ffa4 | |||
| 88180d035c | |||
| b88f9c22a3 | |||
| cb9e594837 | |||
| 0895256dc4 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -161,4 +161,6 @@ cython_debug/
|
|||||||
#.idea/
|
#.idea/
|
||||||
.venv
|
.venv
|
||||||
assets/*
|
assets/*
|
||||||
*.pt
|
*.pt
|
||||||
|
|
||||||
|
big-lama
|
||||||
22
README.md
22
README.md
@@ -12,8 +12,8 @@ I first realized that a normal mosaic algorithm isn't safe AT ALL seeing this pr
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Step 1: Create and activate virtual environment
|
# Step 1: Create and activate virtual environment
|
||||||
python3 -m venv .venv
|
python3.8 -m venv .venv
|
||||||
source venv/bin/activate
|
source .venv/bin/activate
|
||||||
|
|
||||||
# Step 2: Install the local Python program add the -e flag for development
|
# Step 2: Install the local Python program add the -e flag for development
|
||||||
pip install .
|
pip install .
|
||||||
@@ -21,3 +21,21 @@ pip install .
|
|||||||
# Step 3: Run the secure-pixelation command
|
# Step 3: Run the secure-pixelation command
|
||||||
secure-pixelation
|
secure-pixelation
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Setup LaMa
|
||||||
|
|
||||||
|
This is the generative ai model to impaint the blacked out areas.
|
||||||
|
|
||||||
|
```
|
||||||
|
# get the pretrained models
|
||||||
|
mkdir -p ./big-lama
|
||||||
|
wget https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
|
||||||
|
unzip big-lama.zip -d ./big-lama
|
||||||
|
rm big-lama.zip
|
||||||
|
|
||||||
|
# get the code to run the models
|
||||||
|
cd big-lama
|
||||||
|
git clone https://github.com/advimman/lama.git
|
||||||
|
cd lama
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|||||||
57
deblur/deblur.py
Normal file
57
deblur/deblur.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
image = np.array([1, 3, 1, 2, 1, 6, 1], dtype=np.float32)
|
||||||
|
kernel = np.array([1, 2, 1], dtype=np.float32) / 4
|
||||||
|
|
||||||
|
blurred = np.convolve(image, kernel, mode="same")
|
||||||
|
|
||||||
|
print(image)
|
||||||
|
print(blurred)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("building linalg")
|
||||||
|
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html
|
||||||
|
a = []
|
||||||
|
b = []
|
||||||
|
|
||||||
|
for i in range(len(blurred)):
|
||||||
|
y = blurred[i]
|
||||||
|
|
||||||
|
shift = i - 1
|
||||||
|
equation = np.zeros(len(image))
|
||||||
|
# Calculate valid range in the output array
|
||||||
|
start_eq = max(0, shift)
|
||||||
|
end_eq = min(len(image), shift + len(kernel))
|
||||||
|
|
||||||
|
# Corresponding range in the kernel
|
||||||
|
start_k = start_eq - shift # how much to cut from the beginning of the kernel
|
||||||
|
end_k = start_k + (end_eq - start_eq)
|
||||||
|
|
||||||
|
|
||||||
|
# Assign the clipped kernel segment
|
||||||
|
equation[start_eq:end_eq] = kernel[start_k:end_k]
|
||||||
|
|
||||||
|
a.append(equation)
|
||||||
|
b.append(y)
|
||||||
|
goal = image[i]
|
||||||
|
print(f"{i} ({goal}): {y} = {equation}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("deblurring")
|
||||||
|
deblurred = np.linalg.solve(a, b)
|
||||||
|
print(deblurred)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def show_matrix(m):
|
||||||
|
# Resize the image to make it visible (e.g., scale up to 200x200 pixels)
|
||||||
|
scaled_image = cv2.resize(m, (200, 200), interpolation=cv2.INTER_NEAREST)
|
||||||
|
|
||||||
|
# Display the image
|
||||||
|
cv2.imshow('Test Matrix', scaled_image)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
419
deblur/deblur_2d.py
Normal file
419
deblur/deblur_2d.py
Normal file
@@ -0,0 +1,419 @@
|
|||||||
|
import numpy as np
|
||||||
|
from scipy.signal import convolve2d
|
||||||
|
from scipy.sparse import lil_matrix
|
||||||
|
from scipy.sparse.linalg import spsolve
|
||||||
|
from scipy.optimize import curve_fit
|
||||||
|
import cv2
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from pathlib import Path
|
||||||
|
from scipy.ndimage import correlate
|
||||||
|
from skimage.restoration import richardson_lucy
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
matplotlib.use('qtagg')
|
||||||
|
|
||||||
|
"""
|
||||||
|
https://setosa.io/ev/image-kernels/
|
||||||
|
https://openaccess.thecvf.com/content/CVPR2021/papers/Tran_Explore_Image_Deblurring_via_Encoded_Blur_Kernel_Space_CVPR_2021_paper.pdf
|
||||||
|
"""
|
||||||
|
|
||||||
|
def show(img):
|
||||||
|
cv2.imshow('image',img.astype(np.uint8))
|
||||||
|
cv2.waitKey(0)
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
|
def demo(image_file):
|
||||||
|
# Define 2D image and kernel
|
||||||
|
image = cv2.imread(image_file, 0)
|
||||||
|
image = cv2.resize(image, (200, 200), interpolation= cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
kernel = np.array([
|
||||||
|
[1, 2, 1],
|
||||||
|
[2, 4, 2],
|
||||||
|
[1, 2, 1]
|
||||||
|
], dtype=np.float32)
|
||||||
|
kernel /= kernel.sum() # Normalize
|
||||||
|
|
||||||
|
print(kernel)
|
||||||
|
|
||||||
|
# Perform 2D convolution (blurring)
|
||||||
|
blurred = convolve2d(image, kernel, mode="same", boundary="fill", fillvalue=0)
|
||||||
|
|
||||||
|
h, w = image.shape
|
||||||
|
kh, kw = kernel.shape
|
||||||
|
pad_h, pad_w = kh // 2, kw // 2
|
||||||
|
|
||||||
|
|
||||||
|
show(image)
|
||||||
|
show(blurred)
|
||||||
|
|
||||||
|
print("Original image:\n", image)
|
||||||
|
print("\nBlurred image:\n", blurred)
|
||||||
|
|
||||||
|
print("\nBuilding linear system for deconvolution...")
|
||||||
|
|
||||||
|
# Step 2: Build sparse matrix A
|
||||||
|
N = h * w
|
||||||
|
A = lil_matrix((N, N), dtype=np.float32)
|
||||||
|
b = blurred.flatten()
|
||||||
|
|
||||||
|
def index(y, x):
|
||||||
|
return y * w + x
|
||||||
|
|
||||||
|
for y in range(h):
|
||||||
|
for x in range(w):
|
||||||
|
row_idx = index(y, x)
|
||||||
|
for ky in range(kh):
|
||||||
|
for kx in range(kw):
|
||||||
|
iy = y + ky - pad_h
|
||||||
|
ix = x + kx - pad_w
|
||||||
|
if 0 <= iy < h and 0 <= ix < w:
|
||||||
|
col_idx = index(iy, ix)
|
||||||
|
A[row_idx, col_idx] += kernel[ky, kx]
|
||||||
|
|
||||||
|
# Step 3: Solve the sparse system A * x = b
|
||||||
|
x = spsolve(A.tocsr(), b)
|
||||||
|
deblurred = x.reshape((h, w))
|
||||||
|
|
||||||
|
print("\nDeblurred image:\n", np.round(deblurred, 2))
|
||||||
|
|
||||||
|
show(deblurred)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mask(image_file):
|
||||||
|
mask_file = Path(image_file)
|
||||||
|
mask_file = mask_file.with_name("mask_" + mask_file.name)
|
||||||
|
|
||||||
|
if mask_file.exists():
|
||||||
|
return cv2.imread(str(mask_file), 0)
|
||||||
|
|
||||||
|
drawing = False # True when mouse is pressed
|
||||||
|
brush_size = 5
|
||||||
|
image = cv2.imread(image_file)
|
||||||
|
mask = np.zeros(image.shape[:2], dtype=np.uint8)
|
||||||
|
clone = image.copy()
|
||||||
|
|
||||||
|
def draw_mask(event, x, y, flags, param):
|
||||||
|
nonlocal drawing, mask, brush_size
|
||||||
|
|
||||||
|
if event == cv2.EVENT_LBUTTONDOWN:
|
||||||
|
drawing = True
|
||||||
|
elif event == cv2.EVENT_MOUSEMOVE:
|
||||||
|
if drawing:
|
||||||
|
cv2.circle(mask, (x, y), brush_size, 255, -1)
|
||||||
|
cv2.circle(image, (x, y), brush_size, (0, 0, 255), -1)
|
||||||
|
elif event == cv2.EVENT_LBUTTONUP:
|
||||||
|
drawing = False
|
||||||
|
|
||||||
|
|
||||||
|
cv2.namedWindow("Draw Mask")
|
||||||
|
cv2.setMouseCallback("Draw Mask", draw_mask)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
display = image.copy()
|
||||||
|
cv2.putText(display, f'Brush size: {brush_size}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255,0), 2)
|
||||||
|
cv2.imshow("Draw Mask", display)
|
||||||
|
key = cv2.waitKey(1) & 0xFF
|
||||||
|
if key == 13: # Enter to finish
|
||||||
|
break
|
||||||
|
elif key == ord('+') or key == ord('='): # `=` for some keyboard layouts
|
||||||
|
brush_size = min(100, brush_size + 1)
|
||||||
|
elif key == ord('-') or key == ord('_'):
|
||||||
|
brush_size = max(1, brush_size - 1)
|
||||||
|
|
||||||
|
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
cv2.imwrite(str(mask_file), mask)
|
||||||
|
|
||||||
|
# Apply mask
|
||||||
|
masked_image = cv2.bitwise_and(clone, clone, mask=mask)
|
||||||
|
|
||||||
|
cv2.imshow("Masked Image", masked_image)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def color_edge_detection(image, threshold=30):
|
||||||
|
img_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
|
||||||
|
L, A, B = cv2.split(img_lab)
|
||||||
|
|
||||||
|
def gradient_magnitude(channel):
|
||||||
|
gx = cv2.Sobel(channel, cv2.CV_64F, 1, 0, ksize=3)
|
||||||
|
gy = cv2.Sobel(channel, cv2.CV_64F, 0, 1, ksize=3)
|
||||||
|
return gx, gy
|
||||||
|
|
||||||
|
gxL, gyL = gradient_magnitude(L)
|
||||||
|
gxA, gyA = gradient_magnitude(A)
|
||||||
|
gxB, gyB = gradient_magnitude(B)
|
||||||
|
|
||||||
|
gx_total = gxL**2 + gxA**2 + gxB**2
|
||||||
|
gy_total = gyL**2 + gyA**2 + gyB**2
|
||||||
|
magnitude = np.sqrt(gx_total + gy_total)
|
||||||
|
|
||||||
|
magnitude = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
|
||||||
|
edges = (magnitude > threshold).astype(np.uint8) * 255
|
||||||
|
return edges, magnitude
|
||||||
|
|
||||||
|
|
||||||
|
# === Step 2: Extract Vertical Profile ===
|
||||||
|
def extract_vertical_profile(image, center_x, center_y, length=21):
|
||||||
|
half_len = length // 2
|
||||||
|
y_range = np.clip(np.arange(center_y - half_len, center_y + half_len + 1), 0, image.shape[0] - 1)
|
||||||
|
profile = image[y_range, center_x].astype(np.float64)
|
||||||
|
profile -= profile.min()
|
||||||
|
if profile.max() > 0:
|
||||||
|
profile /= profile.max()
|
||||||
|
return profile, y_range - center_y # profile, x-axis
|
||||||
|
|
||||||
|
# === Step 3: Fit Gaussian ===
|
||||||
|
def gaussian(x, amp, mu, sigma):
|
||||||
|
return amp * np.exp(-(x - mu)**2 / (2 * sigma**2))
|
||||||
|
|
||||||
|
def fit_gaussian(profile, x_vals):
|
||||||
|
p0 = [1.0, 0.0, 2.0] # initial guess: amp, mu, sigma
|
||||||
|
popt, _ = curve_fit(gaussian, x_vals, profile, p0=p0)
|
||||||
|
return popt # amp, mu, sigma
|
||||||
|
|
||||||
|
# === Step 4: Create Gaussian Kernel ===
|
||||||
|
def create_gaussian_kernel(sigma):
|
||||||
|
ksize = int(sigma * 6) | 1 # ensure odd size
|
||||||
|
kernel_1d = cv2.getGaussianKernel(ksize, sigma)
|
||||||
|
kernel_2d = kernel_1d @ kernel_1d.T
|
||||||
|
return kernel_2d
|
||||||
|
|
||||||
|
|
||||||
|
def kernel_detection(blurred, mask, edge_threshold=30, profile_length=21):
|
||||||
|
edges, gradient_mag = color_edge_detection(blurred, threshold=edge_threshold)
|
||||||
|
edges = cv2.bitwise_and(edges, edges, mask=mask)
|
||||||
|
# show(edges)
|
||||||
|
|
||||||
|
|
||||||
|
# Find central edge pixel
|
||||||
|
y_idxs, x_idxs = np.where(edges > 0)
|
||||||
|
if len(x_idxs) == 0:
|
||||||
|
raise RuntimeError("No edges found.")
|
||||||
|
idx = len(x_idxs) // 2
|
||||||
|
cx, cy = x_idxs[idx], y_idxs[idx]
|
||||||
|
|
||||||
|
gray = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)
|
||||||
|
profile, x_vals = extract_vertical_profile(gray, cx, cy, length=profile_length)
|
||||||
|
popt = fit_gaussian(profile, x_vals)
|
||||||
|
amp, mu, sigma = popt
|
||||||
|
|
||||||
|
print(f"Estimated Gaussian sigma: {sigma:.2f}")
|
||||||
|
|
||||||
|
kernel = create_gaussian_kernel(sigma)
|
||||||
|
|
||||||
|
# print(kernel)
|
||||||
|
|
||||||
|
return kernel / kernel.sum()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def kernel_detection_box(blurred, mask, edge_threshold=30, profile_length=21):
|
||||||
|
def box_function(x, amp, center, width):
|
||||||
|
"""Simple box profile: flat region with sharp transitions."""
|
||||||
|
return amp * ((x >= (center - width / 2)) & (x <= (center + width / 2))).astype(float)
|
||||||
|
|
||||||
|
def fit_box(profile, x_vals):
|
||||||
|
# Initial guess: full amplitude, centered at 0, small width
|
||||||
|
p0 = [1.0, 0.0, 5.0]
|
||||||
|
bounds = ([0, -10, 1], [1.5, 10, len(x_vals)]) # reasonable bounds
|
||||||
|
popt, _ = curve_fit(box_function, x_vals, profile, p0=p0, bounds=bounds)
|
||||||
|
return popt # amp, center, width
|
||||||
|
|
||||||
|
def create_box_kernel(width):
|
||||||
|
"""Generate a normalized 2D box kernel."""
|
||||||
|
ksize = int(round(width))
|
||||||
|
if ksize < 1:
|
||||||
|
ksize = 1
|
||||||
|
if ksize % 2 == 0:
|
||||||
|
ksize += 1 # ensure odd size
|
||||||
|
kernel = np.ones((ksize, ksize), dtype=np.float32)
|
||||||
|
return kernel / kernel.sum()
|
||||||
|
|
||||||
|
|
||||||
|
edges, gradient_mag = color_edge_detection(blurred, threshold=edge_threshold)
|
||||||
|
edges = cv2.bitwise_and(edges, edges, mask=mask)
|
||||||
|
|
||||||
|
y_idxs, x_idxs = np.where(edges > 0)
|
||||||
|
if len(x_idxs) == 0:
|
||||||
|
raise RuntimeError("No edges found.")
|
||||||
|
idx = len(x_idxs) // 2
|
||||||
|
cx, cy = x_idxs[idx], y_idxs[idx]
|
||||||
|
|
||||||
|
gray = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)
|
||||||
|
profile, x_vals = extract_vertical_profile(gray, cx, cy, length=profile_length)
|
||||||
|
popt = fit_box(profile, x_vals)
|
||||||
|
amp, mu, width = popt
|
||||||
|
|
||||||
|
print(f"Estimated box width: {width:.2f} pixels")
|
||||||
|
|
||||||
|
kernel = create_box_kernel(width)
|
||||||
|
return kernel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def deconvolution(image_file, edge_threshold=30, profile_length=21):
|
||||||
|
image = cv2.imread(image_file)
|
||||||
|
mask = get_mask(image_file)
|
||||||
|
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
kernel = kernel_detection_box(image, mask, edge_threshold=edge_threshold, profile_length=profile_length)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply Richardson-Lucy to each channel
|
||||||
|
num_iter = 30
|
||||||
|
deblurred_channels = []
|
||||||
|
for i in range(3): # R, G, B
|
||||||
|
channel = image_rgb[..., i]
|
||||||
|
deconv = richardson_lucy(channel, kernel, num_iter=num_iter)
|
||||||
|
deblurred_channels.append(deconv)
|
||||||
|
|
||||||
|
# Stack back into an RGB image
|
||||||
|
deblurred_rgb = np.stack(deblurred_channels, axis=2)
|
||||||
|
deblurred_rgb = np.clip(deblurred_rgb, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
# Show result
|
||||||
|
plt.figure(figsize=(10, 5))
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
plt.imshow(image_rgb)
|
||||||
|
plt.title("Blurred Image")
|
||||||
|
plt.axis('off')
|
||||||
|
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
plt.imshow(deblurred_rgb)
|
||||||
|
plt.title("Deconvolved Image")
|
||||||
|
plt.axis('off')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def sharpness_heatmap(image, block_size=32, threshold=30):
|
||||||
|
"""
|
||||||
|
Compute a sharpness heatmap using color-aware Laplacian variance over blocks,
|
||||||
|
generate a binary mask highlighting blurred areas, and smooth the edges of the mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: BGR or RGB image (NumPy array).
|
||||||
|
block_size: Size of the square block to compute sharpness.
|
||||||
|
sigma: Standard deviation for Gaussian smoothing of the heatmap.
|
||||||
|
threshold: Sharpness threshold to define blurred regions (between 0 and 1).
|
||||||
|
smoothing_sigma: Standard deviation for Gaussian smoothing of the binary mask edges.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
blurred_mask: Binary mask highlighting blurred areas (0 = sharp, 255 = blurred).
|
||||||
|
"""
|
||||||
|
if image.ndim != 3 or image.shape[2] != 3:
|
||||||
|
raise ValueError("Input must be a color image (3 channels)")
|
||||||
|
|
||||||
|
h, w, _ = image.shape
|
||||||
|
heatmap = np.zeros((h // block_size, w // block_size))
|
||||||
|
|
||||||
|
# Calculate sharpness for each block
|
||||||
|
for y in range(0, h - block_size + 1, block_size):
|
||||||
|
for x in range(0, w - block_size + 1, block_size):
|
||||||
|
block = image[y:y + block_size, x:x + block_size, :]
|
||||||
|
sharpness_vals = []
|
||||||
|
|
||||||
|
for c in range(3): # For R, G, B channels
|
||||||
|
channel = block[..., c]
|
||||||
|
lap_var = cv2.Laplacian(channel, cv2.CV_64F).var()
|
||||||
|
sharpness_vals.append(lap_var)
|
||||||
|
|
||||||
|
# Use average sharpness across color channels
|
||||||
|
heatmap[y // block_size, x // block_size] = np.mean(sharpness_vals)
|
||||||
|
|
||||||
|
print(heatmap)
|
||||||
|
|
||||||
|
# Threshold the heatmap to create a binary mask (blurred regions)
|
||||||
|
mask = heatmap < threshold
|
||||||
|
mask = (mask * 255).astype(np.uint8) # Convert to binary mask (0, 255)
|
||||||
|
|
||||||
|
# Display Heatmap
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
plt.imshow(heatmap, cmap='hot', interpolation='nearest')
|
||||||
|
plt.title("Sharpness Heatmap")
|
||||||
|
plt.colorbar(label='Sharpness')
|
||||||
|
|
||||||
|
# Display Smoothed Mask
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
plt.imshow(mask, cmap='gray', interpolation='nearest')
|
||||||
|
plt.title("Smoothed Mask for Blurred Areas")
|
||||||
|
plt.colorbar(label='Blurred Mask')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
return smoothed_mask
|
||||||
|
|
||||||
|
|
||||||
|
def graininess_heatmap(image, block_size=32, threshold=100):
|
||||||
|
"""
|
||||||
|
Compute a graininess heatmap using local variance (texture/noise) over blocks.
|
||||||
|
No smoothing or blurring is applied.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: BGR or RGB image (NumPy array).
|
||||||
|
block_size: Size of the square block to compute variance (graininess).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
graininess_map: Heatmap highlighting the graininess (texture/noise) in the image.
|
||||||
|
"""
|
||||||
|
if image.ndim != 3 or image.shape[2] != 3:
|
||||||
|
raise ValueError("Input must be a color image (3 channels)")
|
||||||
|
|
||||||
|
h, w, _ = image.shape
|
||||||
|
graininess_map = np.zeros((h // block_size, w // block_size))
|
||||||
|
|
||||||
|
# Calculate variance for each block
|
||||||
|
for y in range(0, h - block_size + 1, block_size):
|
||||||
|
for x in range(0, w - block_size + 1, block_size):
|
||||||
|
block = image[y:y + block_size, x:x + block_size, :]
|
||||||
|
variance_vals = []
|
||||||
|
|
||||||
|
for c in range(3): # For R, G, B channels
|
||||||
|
channel = block[..., c]
|
||||||
|
variance = np.var(channel)
|
||||||
|
variance_vals.append(variance)
|
||||||
|
|
||||||
|
# Use average variance across color channels for graininess
|
||||||
|
graininess_map[y // block_size, x // block_size] = np.mean(variance_vals)
|
||||||
|
|
||||||
|
|
||||||
|
mask = graininess_map < threshold
|
||||||
|
mask = (mask * 255).astype(np.uint8) # Convert to binary mask (0, 255)
|
||||||
|
|
||||||
|
# Display graininess_map
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
plt.imshow(graininess_map, cmap='hot', interpolation='nearest')
|
||||||
|
plt.title("Graininess Heatmap")
|
||||||
|
plt.colorbar(label='Graininess')
|
||||||
|
|
||||||
|
# Display Smoothed Mask
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
plt.imshow(mask, cmap='gray', interpolation='nearest')
|
||||||
|
plt.title("Mask for Blurred Areas")
|
||||||
|
plt.colorbar(label='Blurred Mask')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
return graininess_map
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
img_file = "assets/real_test.jpg"
|
||||||
|
|
||||||
|
#demo("assets/omas.png")
|
||||||
|
# deconvolution(img_file, edge_threshold=5)
|
||||||
|
|
||||||
|
image = cv2.imread(img_file)
|
||||||
|
test = graininess_heatmap(image)
|
||||||
|
heatmap = sharpness_heatmap(image)
|
||||||
251
deblur/symetric_kernel.py
Normal file
251
deblur/symetric_kernel.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from PyQt5.QtWidgets import (
|
||||||
|
QApplication, QWidget, QLabel, QSlider, QVBoxLayout,
|
||||||
|
QHBoxLayout, QGridLayout, QPushButton, QFileDialog
|
||||||
|
)
|
||||||
|
from PyQt5.QtCore import Qt
|
||||||
|
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
import scipy.signal
|
||||||
|
from scipy.signal import convolve2d
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH", None)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_box_kernel(size):
|
||||||
|
return np.ones((size, size), dtype=np.float32) / (size * size)
|
||||||
|
|
||||||
|
def generate_disk_kernel(radius):
|
||||||
|
size = 2 * radius + 1
|
||||||
|
y, x = np.ogrid[-radius:radius+1, -radius:radius+1]
|
||||||
|
mask = x**2 + y**2 <= radius**2
|
||||||
|
kernel = np.zeros((size, size), dtype=np.float32)
|
||||||
|
kernel[mask] = 1
|
||||||
|
kernel /= kernel.sum()
|
||||||
|
return kernel
|
||||||
|
|
||||||
|
def generate_kernel(radius, sigma=None):
|
||||||
|
"""
|
||||||
|
Generate a 2D Gaussian kernel with a given radius.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- radius: int, the radius of the kernel (size will be 2*radius + 1)
|
||||||
|
- sigma: float (optional), standard deviation of the Gaussian. If None, sigma = radius / 3
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- kernel: 2D numpy array of shape (2*radius+1, 2*radius+1)
|
||||||
|
"""
|
||||||
|
size = 2 * radius + 1
|
||||||
|
if sigma is None:
|
||||||
|
sigma = radius / 3.0 # Common default choice
|
||||||
|
print(f"radius: {radius}, sigma: {sigma}")
|
||||||
|
|
||||||
|
# Create a grid of (x,y) coordinates
|
||||||
|
ax = np.arange(-radius, radius + 1)
|
||||||
|
xx, yy = np.meshgrid(ax, ax)
|
||||||
|
|
||||||
|
# Apply the 2D Gaussian formula
|
||||||
|
kernel = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
|
||||||
|
kernel /= 2 * np.pi * sigma**2 # Normalize based on Gaussian PDF
|
||||||
|
kernel /= kernel.sum() # Normalize to sum to 1
|
||||||
|
|
||||||
|
return kernel
|
||||||
|
|
||||||
|
|
||||||
|
def wiener_deconvolution(blurred, kernel, K=0.1):
|
||||||
|
"""
|
||||||
|
Perform Wiener deconvolution on a 2D image.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- blurred: 2D numpy array (blurred image)
|
||||||
|
- kernel: 2D numpy array (PSF / blur kernel)
|
||||||
|
- K: float, estimated noise-to-signal ratio
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- deconvolved: 2D numpy array (deblurred image)
|
||||||
|
"""
|
||||||
|
# Pad kernel to image size
|
||||||
|
kernel /= np.sum(kernel)
|
||||||
|
pad = [(0, blurred.shape[0] - kernel.shape[0]),
|
||||||
|
(0, blurred.shape[1] - kernel.shape[1])]
|
||||||
|
kernel_padded = np.pad(kernel, pad, 'constant')
|
||||||
|
|
||||||
|
# FFT of image and kernel
|
||||||
|
H = np.fft.fft2(kernel_padded)
|
||||||
|
G = np.fft.fft2(blurred)
|
||||||
|
|
||||||
|
# Avoid division by zero
|
||||||
|
H_conj = np.conj(H)
|
||||||
|
denominator = H_conj * H + K
|
||||||
|
F_hat = H_conj / denominator * G
|
||||||
|
|
||||||
|
# Inverse FFT to get result
|
||||||
|
deconvolved = np.fft.ifft2(F_hat)
|
||||||
|
deconvolved = np.abs(deconvolved)
|
||||||
|
deconvolved = np.clip(deconvolved, 0, 255)
|
||||||
|
return deconvolved.astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def richardson_lucy(image, psf, iterations=30, clip=True):
|
||||||
|
image = image.astype(np.float32) + 1e-6
|
||||||
|
psf = psf / psf.sum()
|
||||||
|
estimate = np.full(image.shape, 0.5, dtype=np.float32)
|
||||||
|
|
||||||
|
psf_mirror = psf[::-1, ::-1]
|
||||||
|
|
||||||
|
for _ in range(iterations):
|
||||||
|
conv = convolve2d(estimate, psf, mode='same', boundary='wrap')
|
||||||
|
relative_blur = image / (conv + 1e-6)
|
||||||
|
estimate *= convolve2d(relative_blur, psf_mirror, mode='same', boundary='wrap')
|
||||||
|
|
||||||
|
if clip:
|
||||||
|
estimate = np.clip(estimate, 0, 255)
|
||||||
|
|
||||||
|
return estimate
|
||||||
|
|
||||||
|
|
||||||
|
class KernelVisualizer(QWidget):
|
||||||
|
def __init__(self, image_path=None):
|
||||||
|
super().__init__()
|
||||||
|
self.setWindowTitle("Gaussian Kernel Visualizer")
|
||||||
|
self.image = None
|
||||||
|
self.deconvolved = None
|
||||||
|
|
||||||
|
self.load_button = QPushButton("Load Image")
|
||||||
|
self.load_button.clicked.connect(self.load_image)
|
||||||
|
|
||||||
|
self.radius_slider = QSlider(Qt.Horizontal)
|
||||||
|
self.radius_slider.setRange(1, 100)
|
||||||
|
self.radius_slider.setValue(5)
|
||||||
|
|
||||||
|
self.sigma_slider = QSlider(Qt.Horizontal)
|
||||||
|
self.sigma_slider.setRange(1, 300)
|
||||||
|
self.sigma_slider.setValue(15)
|
||||||
|
|
||||||
|
self.radius_slider.valueChanged.connect(self.update_visualization)
|
||||||
|
self.sigma_slider.valueChanged.connect(self.update_visualization)
|
||||||
|
|
||||||
|
self.kernel_fig = Figure(figsize=(3, 3))
|
||||||
|
self.kernel_canvas = FigureCanvas(self.kernel_fig)
|
||||||
|
|
||||||
|
self.image_fig = Figure(figsize=(6, 3))
|
||||||
|
self.image_canvas = FigureCanvas(self.image_fig)
|
||||||
|
|
||||||
|
self.iter_slider = QSlider(Qt.Horizontal)
|
||||||
|
self.iter_slider.setRange(1, 50)
|
||||||
|
self.iter_slider.setValue(10)
|
||||||
|
|
||||||
|
self.apply_button = QPushButton("Do Deconvolution.")
|
||||||
|
self.apply_button.clicked.connect(self.apply_kernel)
|
||||||
|
|
||||||
|
layout = QVBoxLayout()
|
||||||
|
layout.addWidget(self.load_button)
|
||||||
|
|
||||||
|
|
||||||
|
sliders_layout = QGridLayout()
|
||||||
|
sliders_layout.addWidget(QLabel("Radius:"), 0, 0)
|
||||||
|
sliders_layout.addWidget(self.radius_slider, 0, 1)
|
||||||
|
sliders_layout.addWidget(QLabel("Sigma:"), 1, 0)
|
||||||
|
sliders_layout.addWidget(self.sigma_slider, 1, 1)
|
||||||
|
|
||||||
|
sliders_layout.addWidget(QLabel("Iterations:"), 2, 0)
|
||||||
|
sliders_layout.addWidget(self.iter_slider, 2, 1)
|
||||||
|
sliders_layout.addWidget(self.apply_button, 3, 1)
|
||||||
|
|
||||||
|
layout.addLayout(sliders_layout)
|
||||||
|
layout.addWidget(QLabel("Kernel Visualization:"))
|
||||||
|
layout.addWidget(self.kernel_canvas)
|
||||||
|
layout.addWidget(QLabel("Original and Deconvolved Image:"))
|
||||||
|
layout.addWidget(self.image_canvas)
|
||||||
|
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
if image_path:
|
||||||
|
self.load_image(image_path)
|
||||||
|
else:
|
||||||
|
self.update_visualization()
|
||||||
|
|
||||||
|
def load_image(self, image_path=None):
|
||||||
|
if not image_path:
|
||||||
|
fname, _ = QFileDialog.getOpenFileName(self, "Open Image", "", "Images (*.png *.jpg *.bmp *.jpeg)")
|
||||||
|
image_path = fname
|
||||||
|
|
||||||
|
if image_path:
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
if img is not None:
|
||||||
|
self.image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||||
|
self.update_visualization()
|
||||||
|
|
||||||
|
def load_image(self, image_path=None):
|
||||||
|
if not image_path:
|
||||||
|
fname, _ = QFileDialog.getOpenFileName(self, "Open Image", "", "Images (*.png *.jpg *.bmp *.jpeg)")
|
||||||
|
image_path = fname
|
||||||
|
|
||||||
|
if image_path:
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
if img is not None:
|
||||||
|
self.image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||||
|
self.image = cv2.resize(self.image, (200, 200))
|
||||||
|
self.update_visualization()
|
||||||
|
|
||||||
|
def apply_kernel(self):
|
||||||
|
radius = self.radius_slider.value()
|
||||||
|
sigma = self.sigma_slider.value() / 10.0
|
||||||
|
iterations = self.iter_slider.value()
|
||||||
|
|
||||||
|
kernel = generate_kernel(radius, sigma)
|
||||||
|
|
||||||
|
self.deconvolved = richardson_lucy(self.image, kernel, iterations=iterations)
|
||||||
|
|
||||||
|
self.update_visualization()
|
||||||
|
|
||||||
|
def update_visualization(self):
|
||||||
|
radius = self.radius_slider.value()
|
||||||
|
sigma = self.sigma_slider.value() / 10.0 * (radius / 3)
|
||||||
|
kernel = generate_kernel(radius, sigma)
|
||||||
|
iterations = self.iter_slider.value()
|
||||||
|
|
||||||
|
|
||||||
|
# Kernel Visualization
|
||||||
|
self.kernel_fig.clear()
|
||||||
|
ax = self.kernel_fig.add_subplot(111)
|
||||||
|
cax = ax.imshow(kernel, cmap='hot')
|
||||||
|
self.kernel_fig.colorbar(cax, ax=ax)
|
||||||
|
ax.set_title(f"Kernel (r={radius}, σ={sigma:.2f})")
|
||||||
|
self.kernel_canvas.draw()
|
||||||
|
|
||||||
|
if self.image is not None:
|
||||||
|
self.image_fig.clear()
|
||||||
|
ax1 = self.image_fig.add_subplot(131)
|
||||||
|
ax1.imshow(self.image, cmap='gray')
|
||||||
|
ax1.set_title("Original")
|
||||||
|
ax1.axis('off')
|
||||||
|
|
||||||
|
if self.deconvolved is not None:
|
||||||
|
ax3 = self.image_fig.add_subplot(133)
|
||||||
|
ax3.imshow(self.deconvolved, cmap='gray')
|
||||||
|
ax3.set_title(f"Deconvolved (RL, {iterations} iter)")
|
||||||
|
ax3.axis('off')
|
||||||
|
|
||||||
|
self.image_canvas.draw()
|
||||||
|
else:
|
||||||
|
self.image_fig.clear()
|
||||||
|
ax = self.image_fig.add_subplot(111)
|
||||||
|
ax.text(0.5, 0.5, "No image loaded", fontsize=14, ha='center', va='center')
|
||||||
|
ax.axis('off')
|
||||||
|
self.image_canvas.draw()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
image_path = None
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
image_path = sys.argv[1] # Get image path from command-line argument
|
||||||
|
print(image_path)
|
||||||
|
|
||||||
|
app = QApplication(sys.argv)
|
||||||
|
viewer = KernelVisualizer(image_path=image_path)
|
||||||
|
viewer.show()
|
||||||
|
sys.exit(app.exec_())
|
||||||
@@ -2,7 +2,14 @@
|
|||||||
name = "secure_pixelation"
|
name = "secure_pixelation"
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"torch==2.1.2",
|
||||||
|
"torchvision==0.16.2",
|
||||||
|
|
||||||
"opencv_python~=4.11.0.86",
|
"opencv_python~=4.11.0.86",
|
||||||
|
"numpy<2.0.0",
|
||||||
|
"hf_transfer==0.1.8",
|
||||||
|
"huggingface_hub==0.25.1",
|
||||||
|
|
||||||
"ultralytics~=8.3.114",
|
"ultralytics~=8.3.114",
|
||||||
]
|
]
|
||||||
authors = []
|
authors = []
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
from .get_bounding_boxes import select_bounding_boxes
|
from .get_bounding_boxes import select_bounding_boxes
|
||||||
|
from .pixelation_process import pixelate
|
||||||
|
|
||||||
|
|
||||||
def cli():
|
def cli():
|
||||||
print(f"Running secure_pixelation")
|
print(f"Running secure_pixelation")
|
||||||
|
|
||||||
select_bounding_boxes("assets/human_detection/test.png")
|
pixelate("assets/human_detection/test.png", generative_impaint=True)
|
||||||
|
pixelate("assets/human_detection/humans.png", generative_impaint=False)
|
||||||
|
pixelate("assets/human_detection/rev1.png", generative_impaint=False)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Union, List
|
from typing import Union, List, Tuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -17,13 +17,18 @@ class RawImage:
|
|||||||
self.meta_data = self.read_meta()
|
self.meta_data = self.read_meta()
|
||||||
|
|
||||||
self.image = self.get_image()
|
self.image = self.get_image()
|
||||||
|
|
||||||
def _get_path(self, ending: str, original_suffix: bool = False) -> Path:
|
def _get_path(self, ending: str, original_suffix: bool = False) -> Path:
|
||||||
if original_suffix:
|
if original_suffix:
|
||||||
return self.file.with_name(self.file.stem + "_" + ending + self.file.suffix)
|
return self.file.with_name(self.file.stem + "_" + ending + self.file.suffix)
|
||||||
else:
|
else:
|
||||||
return self.file.with_name(self.file.stem + "_" + ending)
|
return self.file.with_name(self.file.stem + "_" + ending)
|
||||||
|
|
||||||
|
def get_dir(self, name: str) -> Path:
|
||||||
|
p = self._get_path(ending=name, original_suffix=False)
|
||||||
|
p.mkdir(exist_ok=True, parents=True)
|
||||||
|
return p
|
||||||
|
|
||||||
def read_meta(self) -> dict:
|
def read_meta(self) -> dict:
|
||||||
if not self.meta_file.exists():
|
if not self.meta_file.exists():
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
@@ -18,5 +18,5 @@ def select_bounding_boxes(to_detect: str):
|
|||||||
fromCenter=False
|
fromCenter=False
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_image.bounding_boxes.extend(bounding_boxes)
|
raw_image.bounding_boxes.extend(bounding_boxes.tolist())
|
||||||
raw_image.write_meta()
|
raw_image.write_meta()
|
||||||
|
|||||||
95
secure_pixelation/pixelation_process.py
Normal file
95
secure_pixelation/pixelation_process.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from pathlib import Path
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .data_classes import RawImage
|
||||||
|
from .simple_lama_bindings import SimpleLama
|
||||||
|
# https://github.com/okaris/simple-lama/tree/main
|
||||||
|
|
||||||
|
def blackout(raw_image: RawImage) -> np.ndarray:
|
||||||
|
image = raw_image.get_image()
|
||||||
|
|
||||||
|
for box in raw_image.bounding_boxes:
|
||||||
|
cv2.rectangle(image, box, (0, 0, 0), -1)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def get_mask(raw_image: RawImage) -> np.ndarray:
|
||||||
|
mask = np.zeros(raw_image.image.shape[:2], dtype=np.uint8)
|
||||||
|
for (x, y, w, h) in raw_image.bounding_boxes:
|
||||||
|
mask[y:y+h, x:x+w] = 255
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def quick_impaint(raw_image: RawImage, image: Optional[np.ndarray] = None) -> np.ndarray:
|
||||||
|
image = image if image is not None else raw_image.get_image()
|
||||||
|
|
||||||
|
mask = get_mask(raw_image)
|
||||||
|
|
||||||
|
# Apply inpainting using the Telea method
|
||||||
|
return cv2.inpaint(image, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
|
||||||
|
|
||||||
|
|
||||||
|
def do_generative_impaint(raw_image: RawImage, image: Optional[np.ndarray] = None) -> np.ndarray:
|
||||||
|
image = image if image is not None else raw_image.get_image()
|
||||||
|
mask = get_mask(raw_image)
|
||||||
|
|
||||||
|
lama = SimpleLama()
|
||||||
|
return lama(image=image, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def pixelate_regions(raw_image: RawImage, image: Optional[np.ndarray] = None, pixel_size: int = 10) -> np.ndarray:
|
||||||
|
image = image.copy() if image is not None else raw_image.get_image().copy()
|
||||||
|
|
||||||
|
for (x, y, w, h) in raw_image.bounding_boxes:
|
||||||
|
roi = image[y:y+h, x:x+w]
|
||||||
|
|
||||||
|
# Resize down and then back up
|
||||||
|
temp = cv2.resize(roi, (max(1, w // pixel_size), max(1, h // pixel_size)), interpolation=cv2.INTER_LINEAR)
|
||||||
|
pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST)
|
||||||
|
|
||||||
|
image[y:y+h, x:x+w] = pixelated
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def pixelate(to_detect: str, generative_impaint: bool = True, debug_drawings: bool = False):
|
||||||
|
raw_image = RawImage(to_detect)
|
||||||
|
|
||||||
|
step_dir = raw_image.get_dir("steps")
|
||||||
|
def write_image(image: np.ndarray, name: str):
|
||||||
|
nonlocal debug_drawings
|
||||||
|
|
||||||
|
f = str(step_dir / (name + raw_image.file.suffix))
|
||||||
|
|
||||||
|
if debug_drawings:
|
||||||
|
for box in raw_image.bounding_boxes:
|
||||||
|
cv2.rectangle(image, box, (0, 255, 255), 1)
|
||||||
|
|
||||||
|
cv2.imwrite(f, image)
|
||||||
|
|
||||||
|
write_image(raw_image.image, "step_0")
|
||||||
|
|
||||||
|
step_1 = blackout(raw_image)
|
||||||
|
write_image(step_1, "step_1")
|
||||||
|
|
||||||
|
if generative_impaint:
|
||||||
|
step_2 = do_generative_impaint(raw_image, image=step_1)
|
||||||
|
step_2_alt = quick_impaint(raw_image, image=step_1)
|
||||||
|
else:
|
||||||
|
step_2 = quick_impaint(raw_image, image=step_1)
|
||||||
|
step_2_alt = do_generative_impaint(raw_image, image=step_1)
|
||||||
|
write_image(step_2, "step_2")
|
||||||
|
write_image(step_2_alt, "step_2_alt")
|
||||||
|
|
||||||
|
step_3 = pixelate_regions(raw_image, image=step_2)
|
||||||
|
write_image(step_3, "step_3")
|
||||||
77
secure_pixelation/simple_lama_bindings.py
Normal file
77
secure_pixelation/simple_lama_bindings.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import os
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
# https://github.com/okaris/simple-lama/blob/main/src/simple_lama/simple_lama.py
|
||||||
|
|
||||||
|
|
||||||
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
|
def prepare_img_and_mask(image: np.ndarray, mask: np.ndarray, device: torch.device, pad_out_to_modulo: int = 8, scale_factor: float = 1) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
def get_image(img: np.ndarray):
|
||||||
|
img = img.copy()
|
||||||
|
|
||||||
|
if img.ndim == 3:
|
||||||
|
img = np.transpose(img, (2, 0, 1)) # chw
|
||||||
|
elif img.ndim == 2:
|
||||||
|
img = img[np.newaxis, ...]
|
||||||
|
|
||||||
|
return img.astype(np.float32) / 255
|
||||||
|
|
||||||
|
def scale_image(img: np.ndarray, factor: float, interpolation=cv2.INTER_AREA) -> np.ndarray:
|
||||||
|
if img.shape[0] == 1:
|
||||||
|
img = img[0]
|
||||||
|
else:
|
||||||
|
img = np.transpose(img, (1, 2, 0))
|
||||||
|
|
||||||
|
img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
|
||||||
|
return img[None, ...] if img.ndim == 2 else np.transpose(img, (2, 0, 1))
|
||||||
|
|
||||||
|
def pad_img_to_modulo(img, mod):
|
||||||
|
channels, height, width = img.shape
|
||||||
|
out_height = height if height % mod == 0 else ((height // mod + 1) * mod)
|
||||||
|
out_width = width if width % mod == 0 else ((width // mod + 1) * mod)
|
||||||
|
return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode="symmetric")
|
||||||
|
|
||||||
|
out_image = get_image(image)
|
||||||
|
out_mask = get_image(mask)
|
||||||
|
|
||||||
|
if scale_factor != 1:
|
||||||
|
out_image = scale_image(out_image, scale_factor)
|
||||||
|
out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)
|
||||||
|
|
||||||
|
if pad_out_to_modulo > 1:
|
||||||
|
out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
|
||||||
|
out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)
|
||||||
|
|
||||||
|
out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
|
||||||
|
out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
return out_image, (out_mask > 0) * 1
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleLama:
|
||||||
|
"""
|
||||||
|
lama = SimpleLama()
|
||||||
|
result = lama(image, mask)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device=None):
|
||||||
|
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Using device: {self.device}")
|
||||||
|
|
||||||
|
model_path = hf_hub_download("okaris/simple-lama", "big-lama.pt")
|
||||||
|
print(f"using model at {model_path}")
|
||||||
|
self.model = torch.jit.load(model_path, map_location=self.device).eval()
|
||||||
|
|
||||||
|
def __call__(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
||||||
|
image, mask = prepare_img_and_mask(image, mask, self.device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
inpainted = self.model(image, mask)
|
||||||
|
cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||||
|
return np.clip(cur_res * 255, 0, 255).astype(np.uint8)
|
||||||
Reference in New Issue
Block a user