本文使用numpy和pytorch分别操作,对图像进行fft变换和ifft变换,以实现图像的频域分析。pytorch的fft功能在1.7.1版之后才完善,支持CUDA和autograd,可以加入到神经网络中实现一些有趣的操作。
图像的频率是指像素值沿横轴和纵轴方向变化的快慢,注意这里的频率是借用了时间序列信号中的‘频率’概念,但指的是空间上的频率。从振幅和相位两个分量分析,图像的振幅包含了颜色、纹理信息,相位包含了轮廓、形状信息。从高频和低频分量看,高频分量代表了图像像素值发生突变的物体边缘或者嘈杂区域,低频分量则是颜色较均匀和平坦的部分,高频分量虽然占据了较多的数据量,但高频分量人眼不易见,所以可以去掉大部分的高频分量而不影响人眼观感,这也是jpg等一些图像压缩方法的主要原理。
np.fft.fft() 一维傅里叶变换
np.fft.fft2() 二维傅里叶变换
np.fft.fftn() n维傅里叶变换
np.fft.fftshift() 把转换后的频域图的低频部分放到图像中间,仅起到便于观察的作用
np.fft.ifft() 一维傅里叶逆变换
np.fft.ifft2() 二维傅里叶逆变换
np.fft.ifftn() n维傅里叶逆变换
cv2.dft() 图像傅里叶变换 ,注意:数据格式应为float32
cv2.idft() 图像傅里叶逆变换
注意:np.fft.fft2()进行图像傅里叶变换时,数据应为非负,否则用np.fft.ifft2()无法还原。
图像经过二维傅里叶变换到频域后每个像素点都是一个包含实部和虚部的复数,求每个像素复数的幅值和相位就可以得到图像的振幅和相位。通常,图像的振幅包含了图像全局的信息,也就是纹理、色彩等信息,相位则包含了图像局部信息,也就是轮廓、形状等信息。这一点直接观察振幅和相位图看不出来,可以通过改变其中一个分量,然后再还原的方式观察每个分量的作用。
import numpy as np
import cv2
import matplotlib.pyplot as plt
file_path = 'birds.JPEG'
img = cv2.imread(file_path)[:,:,::-1] #cv2默认是BGR通道顺序,这里调整到RGB
img = cv2.resize(img,(500,500))
fre = np.fft.fft2(img,axes=(0,1)) #变换得到的频域图数据是复数组成的
fre_m = np.abs(fre) #幅度谱,求模得到
fre_p = np.angle(fre) #相位谱,求相角得到
#把振幅设为常数
constant = fre_m.mean()
fre_ = constant * np.e**(1j*fre_p) #把幅度谱和相位谱再合并为复数形式的频域图数据
img_onlyphase = np.abs(np.fft.ifft2(fre_,axes=(0,1))) #还原为空间域图像
#把相位设为常数
constant = fre_p.mean()
fre_ = fre_m * np.e**(1j*constant)
img_onlymagnitude = np.abs(np.fft.ifft2(fre_,axes=(0,1)))
plt.figure()
plt.imshow(img.astype('uint8'))
plt.figure()
plt.imshow(img_onlyphase.astype('uint8'))
plt.figure()
plt.imshow(img_onlymagnitude.astype('uint8'))
import numpy as np
img = ... #二维numpy数组,非负
fre = np.fft.fft2(img) #变换得到的频域图数据是复数组成的
fre_shift = np.fft.fftshift(fre)#把低频数据移到频域图的中央,仅仅是为了便于观察
rows, cols = img.shape
crow,ccol = int(rows/2), int(cols/2) #中心位置
#下面把中间的低频部分去掉,所以是高通滤波,如果没有进行移频,就是低通滤波了
mask = np.ones((rows, cols))
mask[crow-10:crow+10, ccol-10:ccol+10] = 0
f = fre_shift * mask
img_ = np.abs(np.fft.ifft2(f))
#下面这样也是低通滤波
mask = np.zeros((rows, cols))
mask[crow-10:crow+10, ccol-10:ccol+10] = 1
f = fre_shift * mask
img_ = np.abs(np.fft.ifft2(f))
import numpy as np
import cv2
file_path = '../yangchaoyue.png'
img = cv2.imread(file_path)
h,w = img.shape[:2]
#生成低通和高通滤波器
lpf = np.zeros((h,w,3))
R = (h+w)//8 #或其他
for x in range(w):
for y in range(h):
if ((x-(w-1)/2)**2 + (y-(h-1)/2)**2) < (R**2):
lpf[y,x,:] = 1
hpf = 1-lpf
freq = np.fft.fft2(img,axes=(0,1))
freq = np.fft.fftshift(freq)
lf = freq * lpf
hf = freq * hpf
#生成低频分量图
img_l = np.abs(np.fft.ifft2(lf,axes=(0,1)))
img_l = np.clip(img_l,0,255) #会产生一些过大值需要截断
img_l = img_l.astype('uint8')
cv2.imwrite(file_path[:-4]+'_LPF.png',img_l)
#生成高频分量图
img_h = np.abs(np.fft.ifft2(hf,axes=(0,1)))
img_h = np.clip(img_h,0,255) #似乎一般不会超,加上保险一些
img_h = img_h.astype('uint8')
cv2.imwrite(file_path[:-4]+'_HPF.png',img_h)
#画出频谱图
freq_view = np.log(1 +np.abs(freq))
freq_view = (freq_view - freq_view.min()) / (freq_view.max() - freq_view.min()) * 255
freq_view = freq_view.astype('uint8').copy()
cv2.circle(freq_view,((w-1)//2,(h-1)//2),R,(255,255,255),2)
cv2.imwrite(file_path[:-4]+'_Freq.png',freq_view)
在pytorch1.7.1版以后才完善了fft功能,之前版本也有torch.fft但不支持反向传播,现已废弃。为了不和旧版混淆,在使用时需要显式导入此包,即
import torch.fft as fft
主要函数有:
fft.fft() 一维傅里叶变换
fft.ifft() 一维傅里叶逆变换
fft.fftn() n维傅里叶变换
fft.ifftn() n维傅里叶逆变换
1.8.0以后版本中才有fft.fft2()和fft.ifft2()来专门做二维变换,但1.8.0的GPU版我怎么也装不上,只能装CPU版。另外在1.8.0版中有fft.fftshift()函数用来把频域中的低频部分放到中央,这是做图像频域变换的通用操作,但在1.7.1版中没有。好在我们使用一些变通操作可以在1.7.1版中解决上述两个问题:一是用fftn和ifftn指定维度来实现二维;二是用torch.roll()操作来实现fftshift功能,即(对于freq.shape=(C,H,W)):
freq = torch.roll(freq,(H//2,W//2),dim=(1,2)) 。
具体使用参见以下代码:
import torch.fft as fft
......(其他代码省略)
lpf = torch.zeros((h,w))
R = (h+w)//8 #或其他
for x in range(w):
for y in range(h):
if ((x-(w-1)/2)**2 + (y-(h-1)/2)**2) < (R**2):
lpf[y,x] = 1
hpf = 1-lpf
hpf, lpf = hpf.cuda(), lpf.cuda()
for X,y in dataloader:
X = X.cuda() # X.shape: (b, c, h, w)
X = unnormalize(X) #注意fft只能处理非负数,通常X是标准化到正态分布的,这里需要把X再变换到[0,1]区间,unnormalize = lambda x: x*std + mu
f = fft.fftn(X,dim=(2,3))
f = torch.roll(f,(h//2,w//2),dims=(2,3)) #移频操作,把低频放到中央
f_l = f * lpf
f_h = f * hpf
X_l = torch.abs(fft.ifftn(f_l,dim=(2,3)))
X_h = torch.abs(fft.ifftn(f_h,dim=(2,3)))