feat: deconvolution

Hazel Noack 2025-05-07 12:38:25 +02:00
parent b4c7512a73
commit 6101a8d5e4


@@ -2,12 +2,18 @@ import numpy as np
 from scipy.signal import convolve2d
 from scipy.sparse import lil_matrix
 from scipy.sparse.linalg import spsolve
+from scipy.optimize import curve_fit
 import cv2
+import matplotlib
 import matplotlib.pyplot as plt
 from pathlib import Path
 from scipy.ndimage import correlate
+from skimage.restoration import richardson_lucy
+import os
+
+matplotlib.use('qtagg')
 
 """
 https://setosa.io/ev/image-kernels/
 https://openaccess.thecvf.com/content/CVPR2021/papers/Tran_Explore_Image_Deblurring_via_Encoded_Blur_Kernel_Space_CVPR_2021_paper.pdf
@@ -132,14 +138,10 @@ def get_mask(image_file):
-def color_edge_detection(img, threshold=30):
-    # Load image
-    img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)
+def color_edge_detection(image, threshold=30):
+    img_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
 
-    # Split Lab channels
     L, A, B = cv2.split(img_lab)
 
-    # Compute gradients using Sobel for each channel
     def gradient_magnitude(channel):
         gx = cv2.Sobel(channel, cv2.CV_64F, 1, 0, ksize=3)
         gy = cv2.Sobel(channel, cv2.CV_64F, 0, 1, ksize=3)
@@ -149,28 +151,100 @@ def color_edge_detection(img, threshold=30):
     gxA, gyA = gradient_magnitude(A)
     gxB, gyB = gradient_magnitude(B)
 
-    # Combine gradients across channels
     gx_total = gxL**2 + gxA**2 + gxB**2
     gy_total = gyL**2 + gyA**2 + gyB**2
     magnitude = np.sqrt(gx_total + gy_total)
 
-    # Normalize and threshold
     magnitude = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
     edges = (magnitude > threshold).astype(np.uint8) * 255
-    return edges
+    return edges, magnitude
+
+
+# === Step 2: Extract Vertical Profile ===
+def extract_vertical_profile(image, center_x, center_y, length=21):
+    half_len = length // 2
+    y_range = np.clip(np.arange(center_y - half_len, center_y + half_len + 1), 0, image.shape[0] - 1)
+    profile = image[y_range, center_x].astype(np.float64)
+    profile -= profile.min()
+    if profile.max() > 0:
+        profile /= profile.max()
+    return profile, y_range - center_y  # profile, x-axis
+
+
+# === Step 3: Fit Gaussian ===
+def gaussian(x, amp, mu, sigma):
+    return amp * np.exp(-(x - mu)**2 / (2 * sigma**2))
+
+
+def fit_gaussian(profile, x_vals):
+    p0 = [1.0, 0.0, 2.0]  # initial guess: amp, mu, sigma
+    popt, _ = curve_fit(gaussian, x_vals, profile, p0=p0)
+    return popt  # amp, mu, sigma
+
+
+# === Step 4: Create Gaussian Kernel ===
+def create_gaussian_kernel(sigma):
+    ksize = int(sigma * 6) | 1  # ensure odd size
+    kernel_1d = cv2.getGaussianKernel(ksize, sigma)
+    kernel_2d = kernel_1d @ kernel_1d.T
+    return kernel_2d
+
+
+def kernel_detection(blurred, mask, edge_threshold=30, profile_length=21):
+    edges, gradient_mag = color_edge_detection(blurred, threshold=edge_threshold)
+    edges = cv2.bitwise_and(edges, edges, mask=mask)
+    # show(edges)
+
+    # Find central edge pixel
+    y_idxs, x_idxs = np.where(edges > 0)
+    if len(x_idxs) == 0:
+        raise RuntimeError("No edges found.")
+
+    idx = len(x_idxs) // 2
+    cx, cy = x_idxs[idx], y_idxs[idx]
+
+    gray = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)
+    profile, x_vals = extract_vertical_profile(gray, cx, cy, length=profile_length)
+    popt = fit_gaussian(profile, x_vals)
+    amp, mu, sigma = popt
+
+    print(f"Estimated Gaussian sigma: {sigma:.2f}")
+
+    kernel = create_gaussian_kernel(sigma)
+    print(kernel)
+    return kernel / kernel.sum()
+
+
-def deconvolution(image_file):
+def deconvolution(image_file, edge_threshold=30, profile_length=21):
     blurred = cv2.imread(image_file)
-    # blurred = cv2.resize(blurred, (200, 200), interpolation= cv2.INTER_LINEAR)
     mask = get_mask(image_file)
-    edges = color_edge_detection(blurred, threshold=8)
-    # edges = cv2.bitwise_and(edges, edges, mask=mask)
-    show(edges)
+    kernel = kernel_detection(blurred, mask, edge_threshold=edge_threshold, profile_length=profile_length)
+
+    test_blurred = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0
+    deconvolved = richardson_lucy(test_blurred, kernel, num_iter=30)
+
+    # Display results
+    plt.figure(figsize=(12, 4))
+    plt.subplot(1, 3, 1)
+    plt.title("Blurred Image")
+    plt.imshow(test_blurred, cmap='gray')
+    plt.axis('off')
+
+    plt.subplot(1, 3, 2)
+    plt.title("Estimated Kernel")
+    plt.imshow(kernel, cmap='hot')
+    plt.axis('off')
+
+    plt.subplot(1, 3, 3)
+    plt.title("Deconvolved Image")
+    plt.imshow(deconvolved, cmap='gray')
+    plt.axis('off')
+
+    plt.tight_layout()
+    plt.show()
@ -179,4 +253,4 @@ if __name__ == "__main__":
img_file = "assets/real_test.jpg" img_file = "assets/real_test.jpg"
#demo("assets/omas.png") #demo("assets/omas.png")
deconvolution(img_file) deconvolution(img_file, edge_threshold=10)
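
For reference, below is a minimal, self-contained sketch of the idea this commit implements: blur a synthetic step edge with a known Gaussian, estimate sigma by fitting a Gaussian to the derivative of an intensity profile taken across the edge, build a normalized PSF from that sigma, and deblur with Richardson-Lucy. The synthetic image, the use of scipy.ndimage.gaussian_filter, and the NumPy-built PSF are illustrative assumptions so the example runs without any test assets; the committed code instead works on real images via cv2, get_mask() and kernel_detection().

import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.optimize import curve_fit
from skimage.restoration import richardson_lucy


def gaussian(x, amp, mu, sigma):
    return amp * np.exp(-(x - mu) ** 2 / (2 * sigma ** 2))


# Synthetic scene (assumption for this sketch): a vertical step edge blurred
# with a known sigma, so the estimate can be checked against ground truth.
true_sigma = 2.5
sharp = np.zeros((101, 101), dtype=np.float64)
sharp[:, 50:] = 1.0
blurred = gaussian_filter(sharp, sigma=true_sigma)

# A horizontal profile across the edge; its derivative approximates the line
# spread function, a 1-D Gaussian whose sigma matches the blur width.
profile = blurred[50, :]
lsf = np.gradient(profile)
x = np.arange(lsf.size) - np.argmax(lsf)
popt, _ = curve_fit(gaussian, x, lsf, p0=[lsf.max(), 0.0, 2.0])
est_sigma = abs(popt[2])
print(f"true sigma {true_sigma:.2f}, estimated sigma {est_sigma:.2f}")

# Build a normalized 2-D Gaussian PSF from the estimate (odd size, ~6 sigma wide).
ksize = max(3, int(est_sigma * 6) | 1)
ax = np.arange(ksize) - ksize // 2
g = np.exp(-ax ** 2 / (2 * est_sigma ** 2))
psf = np.outer(g, g)
psf /= psf.sum()

# Non-blind deconvolution with the estimated PSF.
deconvolved = richardson_lucy(blurred, psf, num_iter=30)
print("deconvolved value range:", deconvolved.min(), deconvolved.max())

Because the derivative of a Gaussian-blurred step edge is itself a Gaussian with the same sigma, the fitted value should land close to the true blur width; that is the same assumption kernel_detection makes when it fits the profile returned by extract_vertical_profile.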