feat: detect color edges

Hazel Noack 2025-05-07 12:02:54 +02:00
parent ed650dcc5d
commit b4c7512a73


@@ -5,6 +5,8 @@ from scipy.sparse.linalg import spsolve
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.ndimage import correlate
"""
https://setosa.io/ev/image-kernels/
@@ -130,67 +132,48 @@ def get_mask(image_file):
def color_edge_detection(img, threshold=30):
    # Convert to Lab color space
    img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)

    # Split Lab channels
    L, A, B = cv2.split(img_lab)

    # Compute gradients using Sobel for each channel
    def gradient_magnitude(channel):
        gx = cv2.Sobel(channel, cv2.CV_64F, 1, 0, ksize=3)
        gy = cv2.Sobel(channel, cv2.CV_64F, 0, 1, ksize=3)
        return gx, gy

    gxL, gyL = gradient_magnitude(L)
    gxA, gyA = gradient_magnitude(A)
    gxB, gyB = gradient_magnitude(B)

    # Combine gradients across channels
    gx_total = gxL**2 + gxA**2 + gxB**2
    gy_total = gyL**2 + gyA**2 + gyB**2
    magnitude = np.sqrt(gx_total + gy_total)

    # Normalize and threshold
    magnitude = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
    edges = (magnitude > threshold).astype(np.uint8) * 255

    return edges
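
# Illustrative sketch: one way to exercise the detector above on its own.
# The helper name below is hypothetical; the path is the test image used in
# __main__ and threshold=8 mirrors the call inside deconvolution().
def _edge_detection_demo(path="assets/real_test.jpg"):
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(path)
    edges = color_edge_detection(img, threshold=8)
    show(edges)
    return edges
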
def deconvolution(image_file):
    image = cv2.imread(image_file, 0)
    # image = cv2.resize(image, (200, 200), interpolation= cv2.INTER_LINEAR)

    blurred = cv2.imread(image_file)
    # blurred = cv2.resize(blurred, (200, 200), interpolation= cv2.INTER_LINEAR)

    mask = get_mask(image_file)

    # Detect color edges in the blurred image
    edges = color_edge_detection(blurred, threshold=8)
    # edges = cv2.bitwise_and(edges, edges, mask=mask)
    show(edges)

    # Define 2D image and kernel
    kernel = np.array([
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]
    ], dtype=np.float32)
    kernel /= kernel.sum()  # Normalize

    print(kernel)
    return  # NOTE: early return; the deconvolution code below is currently unreachable
    # Perform 2D convolution (blurring)
    h, w = image.shape
    kh, kw = kernel.shape
    pad_h, pad_w = kh // 2, kw // 2

    show(image)
    print("Original image:\n", image)
    print("\nBlurred image:\n", blurred)
    print("\nBuilding linear system for deconvolution...")

    # Step 2: Build sparse matrix A
    # Each row of A holds the kernel weights that map the unknown sharp pixels
    # to one observed pixel, with zero padding at the borders.
    N = h * w
    A = lil_matrix((N, N), dtype=np.float32)
    b = image.flatten()

    def index(y, x):
        return y * w + x

    for y in range(h):
        for x in range(w):
            row_idx = index(y, x)
            for ky in range(kh):
                for kx in range(kw):
                    iy = y + ky - pad_h
                    ix = x + kx - pad_w
                    if 0 <= iy < h and 0 <= ix < w:
                        col_idx = index(iy, ix)
                        A[row_idx, col_idx] += kernel[ky, kx]

    # Step 3: Solve the sparse system A * x = b
    x = spsolve(A.tocsr(), b)
    deblurred = x.reshape((h, w))

    print("\nDeblurred image:\n", np.round(deblurred, 2))
    show(deblurred)
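
# Illustrative sanity check of the linear model solved above (the sizes and
# helper name are hypothetical): blur a tiny array with the same zero-padded
# 3x3 box kernel, rebuild the sparse matrix the same way, and confirm that
# spsolve recovers the original values.
def _deconvolution_sanity_check(h=4, w=4):
    rng = np.random.default_rng(0)
    original = rng.random((h, w))
    kernel = np.full((3, 3), 1 / 9)

    # Forward blur with zero padding, matching how A is assembled above
    blurred = correlate(original, kernel, mode="constant", cval=0.0)

    # Same row-by-row assembly as in deconvolution()
    A = lil_matrix((h * w, h * w))
    for y in range(h):
        for x in range(w):
            row = y * w + x
            for ky in range(3):
                for kx in range(3):
                    iy, ix = y + ky - 1, x + kx - 1
                    if 0 <= iy < h and 0 <= ix < w:
                        A[row, iy * w + ix] += kernel[ky, kx]

    recovered = spsolve(A.tocsr(), blurred.flatten()).reshape(h, w)
    # Expected to be near zero for this small, well-conditioned case
    print("max reconstruction error:", np.abs(recovered - original).max())
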
if __name__ == "__main__":
    img_file = "assets/real_test.jpg"