feat: deconvolution

This commit is contained in:
Hazel Noack 2025-05-07 12:38:25 +02:00
parent b4c7512a73
commit 6101a8d5e4

View File

@ -2,12 +2,18 @@ import numpy as np
from scipy.signal import convolve2d
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import spsolve
from scipy.optimize import curve_fit
import cv2
import matplotlib
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.ndimage import correlate
from skimage.restoration import richardson_lucy
import os
matplotlib.use('qtagg')
"""
https://setosa.io/ev/image-kernels/
https://openaccess.thecvf.com/content/CVPR2021/papers/Tran_Explore_Image_Deblurring_via_Encoded_Blur_Kernel_Space_CVPR_2021_paper.pdf
@ -132,14 +138,10 @@ def get_mask(image_file):
def color_edge_detection(img, threshold=30):
# Load image
img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)
# Split Lab channels
def color_edge_detection(image, threshold=30):
img_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
L, A, B = cv2.split(img_lab)
# Compute gradients using Sobel for each channel
def gradient_magnitude(channel):
gx = cv2.Sobel(channel, cv2.CV_64F, 1, 0, ksize=3)
gy = cv2.Sobel(channel, cv2.CV_64F, 0, 1, ksize=3)
@ -149,28 +151,100 @@ def color_edge_detection(img, threshold=30):
gxA, gyA = gradient_magnitude(A)
gxB, gyB = gradient_magnitude(B)
# Combine gradients across channels
gx_total = gxL**2 + gxA**2 + gxB**2
gy_total = gyL**2 + gyA**2 + gyB**2
magnitude = np.sqrt(gx_total + gy_total)
# Normalize and threshold
magnitude = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
edges = (magnitude > threshold).astype(np.uint8) * 255
return edges, magnitude
return edges
# === Step 2: Extract Vertical Profile ===
def extract_vertical_profile(image, center_x, center_y, length=21):
half_len = length // 2
y_range = np.clip(np.arange(center_y - half_len, center_y + half_len + 1), 0, image.shape[0] - 1)
profile = image[y_range, center_x].astype(np.float64)
profile -= profile.min()
if profile.max() > 0:
profile /= profile.max()
return profile, y_range - center_y # profile, x-axis
# === Step 3: Fit Gaussian ===
def gaussian(x, amp, mu, sigma):
return amp * np.exp(-(x - mu)**2 / (2 * sigma**2))
def fit_gaussian(profile, x_vals):
p0 = [1.0, 0.0, 2.0] # initial guess: amp, mu, sigma
popt, _ = curve_fit(gaussian, x_vals, profile, p0=p0)
return popt # amp, mu, sigma
# === Step 4: Create Gaussian Kernel ===
def create_gaussian_kernel(sigma):
ksize = int(sigma * 6) | 1 # ensure odd size
kernel_1d = cv2.getGaussianKernel(ksize, sigma)
kernel_2d = kernel_1d @ kernel_1d.T
return kernel_2d
def kernel_detection(blurred, mask, edge_threshold=30, profile_length=21):
edges, gradient_mag = color_edge_detection(blurred, threshold=edge_threshold)
edges = cv2.bitwise_and(edges, edges, mask=mask)
# show(edges)
# Find central edge pixel
y_idxs, x_idxs = np.where(edges > 0)
if len(x_idxs) == 0:
raise RuntimeError("No edges found.")
idx = len(x_idxs) // 2
cx, cy = x_idxs[idx], y_idxs[idx]
gray = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)
profile, x_vals = extract_vertical_profile(gray, cx, cy, length=profile_length)
popt = fit_gaussian(profile, x_vals)
amp, mu, sigma = popt
print(f"Estimated Gaussian sigma: {sigma:.2f}")
kernel = create_gaussian_kernel(sigma)
print(kernel)
return kernel / kernel.sum()
def deconvolution(image_file):
def deconvolution(image_file, edge_threshold=30, profile_length=21):
blurred = cv2.imread(image_file)
# blurred = cv2.resize(blurred, (200, 200), interpolation= cv2.INTER_LINEAR)
mask = get_mask(image_file)
edges = color_edge_detection(blurred, threshold=8)
# edges = cv2.bitwise_and(edges, edges, mask=mask)
show(edges)
kernel = kernel_detection(blurred, mask, edge_threshold=edge_threshold, profile_length=profile_length)
test_blurred = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0
deconvolved = richardson_lucy(test_blurred, kernel, num_iter=30)
# Display results
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.title("Blurred Image")
plt.imshow(test_blurred, cmap='gray')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.title("Estimated Kernel")
plt.imshow(kernel, cmap='hot')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.title("Deconvolved Image")
plt.imshow(deconvolved, cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.show()
@ -179,4 +253,4 @@ if __name__ == "__main__":
img_file = "assets/real_test.jpg"
#demo("assets/omas.png")
deconvolution(img_file)
deconvolution(img_file, edge_threshold=10)