From b4c7512a737663bd7378bf9a1993a26cfcbbcd12 Mon Sep 17 00:00:00 2001 From: Hazel Noack Date: Wed, 7 May 2025 12:02:54 +0200 Subject: [PATCH] feat: detect color edges --- deblur/deblur_2d.py | 89 ++++++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 53 deletions(-) diff --git a/deblur/deblur_2d.py b/deblur/deblur_2d.py index 3aca925..ed40b88 100644 --- a/deblur/deblur_2d.py +++ b/deblur/deblur_2d.py @@ -5,6 +5,8 @@ from scipy.sparse.linalg import spsolve import cv2 import matplotlib.pyplot as plt from pathlib import Path +from scipy.ndimage import correlate + """ https://setosa.io/ev/image-kernels/ @@ -130,67 +132,48 @@ def get_mask(image_file): +def color_edge_detection(img, threshold=30): + # Load image + img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2Lab) + + # Split Lab channels + L, A, B = cv2.split(img_lab) + + # Compute gradients using Sobel for each channel + def gradient_magnitude(channel): + gx = cv2.Sobel(channel, cv2.CV_64F, 1, 0, ksize=3) + gy = cv2.Sobel(channel, cv2.CV_64F, 0, 1, ksize=3) + return gx, gy + + gxL, gyL = gradient_magnitude(L) + gxA, gyA = gradient_magnitude(A) + gxB, gyB = gradient_magnitude(B) + + # Combine gradients across channels + gx_total = gxL**2 + gxA**2 + gxB**2 + gy_total = gyL**2 + gyA**2 + gyB**2 + magnitude = np.sqrt(gx_total + gy_total) + + # Normalize and threshold + magnitude = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX) + edges = (magnitude > threshold).astype(np.uint8) * 255 + + return edges + + def deconvolution(image_file): - image = cv2.imread(image_file, 0) - # image = cv2.resize(image, (200, 200), interpolation= cv2.INTER_LINEAR) + blurred = cv2.imread(image_file) + # blurred = cv2.resize(blurred, (200, 200), interpolation= cv2.INTER_LINEAR) mask = get_mask(image_file) - - # Define 2D image and kernel + edges = color_edge_detection(blurred, threshold=8) + # edges = cv2.bitwise_and(edges, edges, mask=mask) + show(edges) - kernel = np.array([ - [1, 1, 1], - [1, 1, 1], - [1, 1, 1] - ], dtype=np.float32) - kernel /= kernel.sum() # Normalize - print(kernel) - return - - # Perform 2D convolution (blurring) - - h, w = image.shape - kh, kw = kernel.shape - pad_h, pad_w = kh // 2, kw // 2 - - - show(image) - - print("Original image:\n", image) - print("\nBlurred image:\n", image) - - print("\nBuilding linear system for deconvolution...") - - # Step 2: Build sparse matrix A - N = h * w - A = lil_matrix((N, N), dtype=np.float32) - b = image.flatten() - - def index(y, x): - return y * w + x - - for y in range(h): - for x in range(w): - row_idx = index(y, x) - for ky in range(kh): - for kx in range(kw): - iy = y + ky - pad_h - ix = x + kx - pad_w - if 0 <= iy < h and 0 <= ix < w: - col_idx = index(iy, ix) - A[row_idx, col_idx] += kernel[ky, kx] - - # Step 3: Solve the sparse system A * x = b - x = spsolve(A.tocsr(), b) - deblurred = x.reshape((h, w)) - - print("\nDeblurred image:\n", np.round(deblurred, 2)) - - show(deblurred) if __name__ == "__main__": img_file = "assets/real_test.jpg"