feat: detect color edges

This commit is contained in:
Hazel Noack 2025-05-07 12:02:54 +02:00
parent ed650dcc5d
commit b4c7512a73

View File

@ -5,6 +5,8 @@ from scipy.sparse.linalg import spsolve
import cv2 import cv2
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
from scipy.ndimage import correlate
""" """
https://setosa.io/ev/image-kernels/ https://setosa.io/ev/image-kernels/
@ -130,67 +132,48 @@ def get_mask(image_file):
def color_edge_detection(img, threshold=30):
# Load image
img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)
# Split Lab channels
L, A, B = cv2.split(img_lab)
# Compute gradients using Sobel for each channel
def gradient_magnitude(channel):
gx = cv2.Sobel(channel, cv2.CV_64F, 1, 0, ksize=3)
gy = cv2.Sobel(channel, cv2.CV_64F, 0, 1, ksize=3)
return gx, gy
gxL, gyL = gradient_magnitude(L)
gxA, gyA = gradient_magnitude(A)
gxB, gyB = gradient_magnitude(B)
# Combine gradients across channels
gx_total = gxL**2 + gxA**2 + gxB**2
gy_total = gyL**2 + gyA**2 + gyB**2
magnitude = np.sqrt(gx_total + gy_total)
# Normalize and threshold
magnitude = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
edges = (magnitude > threshold).astype(np.uint8) * 255
return edges
def deconvolution(image_file): def deconvolution(image_file):
image = cv2.imread(image_file, 0) blurred = cv2.imread(image_file)
# image = cv2.resize(image, (200, 200), interpolation= cv2.INTER_LINEAR) # blurred = cv2.resize(blurred, (200, 200), interpolation= cv2.INTER_LINEAR)
mask = get_mask(image_file) mask = get_mask(image_file)
edges = color_edge_detection(blurred, threshold=8)
# Define 2D image and kernel # edges = cv2.bitwise_and(edges, edges, mask=mask)
show(edges)
kernel = np.array([
[1, 1, 1],
[1, 1, 1],
[1, 1, 1]
], dtype=np.float32)
kernel /= kernel.sum() # Normalize
print(kernel)
return
# Perform 2D convolution (blurring)
h, w = image.shape
kh, kw = kernel.shape
pad_h, pad_w = kh // 2, kw // 2
show(image)
print("Original image:\n", image)
print("\nBlurred image:\n", image)
print("\nBuilding linear system for deconvolution...")
# Step 2: Build sparse matrix A
N = h * w
A = lil_matrix((N, N), dtype=np.float32)
b = image.flatten()
def index(y, x):
return y * w + x
for y in range(h):
for x in range(w):
row_idx = index(y, x)
for ky in range(kh):
for kx in range(kw):
iy = y + ky - pad_h
ix = x + kx - pad_w
if 0 <= iy < h and 0 <= ix < w:
col_idx = index(iy, ix)
A[row_idx, col_idx] += kernel[ky, kx]
# Step 3: Solve the sparse system A * x = b
x = spsolve(A.tocsr(), b)
deblurred = x.reshape((h, w))
print("\nDeblurred image:\n", np.round(deblurred, 2))
show(deblurred)
if __name__ == "__main__": if __name__ == "__main__":
img_file = "assets/real_test.jpg" img_file = "assets/real_test.jpg"