feat: speed up the code by huge difference

This commit is contained in:
Hazel Noack 2025-05-05 16:44:47 +02:00
parent f23fd1cdb3
commit 8d6eecaf78

View File

@ -1,5 +1,7 @@
import numpy as np import numpy as np
from scipy.signal import convolve2d from scipy.signal import convolve2d
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import spsolve
import cv2 import cv2
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -25,6 +27,11 @@ print(kernel)
# Perform 2D convolution (blurring) # Perform 2D convolution (blurring)
blurred = convolve2d(image, kernel, mode="same", boundary="fill", fillvalue=0) blurred = convolve2d(image, kernel, mode="same", boundary="fill", fillvalue=0)
h, w = image.shape
kh, kw = kernel.shape
pad_h, pad_w = kh // 2, kw // 2
show(image) show(image)
show(blurred) show(blurred)
@ -33,35 +40,28 @@ print("\nBlurred image:\n", blurred)
print("\nBuilding linear system for deconvolution...") print("\nBuilding linear system for deconvolution...")
# Image size # Step 2: Build sparse matrix A
h, w = image.shape N = h * w
kh, kw = kernel.shape A = lil_matrix((N, N), dtype=np.float32)
pad_h, pad_w = kh // 2, kw // 2 b = blurred.flatten()
# Build matrix A and vector b for Ax = b def index(y, x):
A = [] return y * w + x
b = []
for y in range(h): for y in range(h):
for x in range(w): for x in range(w):
row = np.zeros((h, w), dtype=np.float32) row_idx = index(y, x)
for ky in range(kh): for ky in range(kh):
for kx in range(kw): for kx in range(kw):
iy = y + ky - pad_h iy = y + ky - pad_h
ix = x + kx - pad_w ix = x + kx - pad_w
if 0 <= iy < h and 0 <= ix < w: if 0 <= iy < h and 0 <= ix < w:
row[iy, ix] += kernel[ky, kx] col_idx = index(iy, ix)
A[row_idx, col_idx] += kernel[ky, kx]
A.append(row.flatten()) # Step 3: Solve the sparse system A * x = b
b.append(blurred[y, x]) x = spsolve(A.tocsr(), b)
deblurred = x.reshape((h, w))
A = np.array(A)
b = np.array(b)
# Solve for the deblurred image
deblurred_flat = np.linalg.solve(A, b)
deblurred = deblurred_flat.reshape((h, w))
print("\nDeblurred image:\n", np.round(deblurred, 2)) print("\nDeblurred image:\n", np.round(deblurred, 2))