pytorch中实现根据版本不同,使用的函数也不同,最直接的区别就是旧版fft后出现的是实数,而新版出来的是复数,这里进行一个记录:参考这个
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') x_res = torch.randn(1, 3, 256, 256).to(device) output_fft_old = torch.rfft(x_res, signal_ndim=2, normalized=False, onesided=False) print('output_fft_old:',output_fft_old.size()) x_real = output_fft_old[...,0] print('x_real:',x_real.size()) x_imaginary = output_fft_old[...,1] print('x_imaginary:',x_imaginary.size())
output_ifft_old = torch.irfft(output_fft_old , signal_ndim=2, normalized=False, onesided=False) print('output_ifft_old:',output_ifft_old.size())
--------------------------输出结果--------------------------
output_fft_old: torch.Size([1, 3, 256, 256, 2]) x_real: torch.Size([1, 3, 256, 256]) x_imaginary: torch.Size([1, 3, 256, 256]) output_ifft_old: torch.Size([1, 3, 256, 256])
--------------------------输出结果--------------------------
新版的rfft和fft都是用于一维输入,而我们的图像是二维,所以应该用rfft2和fft2。在fft2中,参数dim用来指定用于傅里叶变换的维度,默认(-2,-1),正好对应H、W两个维度。
我打印了结果,可以看出来此时fft结束后输出的是复数
打印实部,这个.real 在旧版里面测试是没有的
打印虚部
进行反变换
output_ifft_new = torch.fft.ifft2(torch.complex(output_new_2dim[..., 0], output_new_2dim[..., 1]), dim=(-2, -1)) # 如果运行了torch.stack()
output_ifft_new = torch.fft.ifft2(output_new_2dim , dim=(-2, -1))
或者直接获取实部和虚部
利用幅值和相位进行反变换
x_fft_res = x_fft_res.real 这个就是反变换后的结果
则在新版里面的frequency loss可以是:
train_out_J_fft = torch.fft.fft2(train_out_J, dim=(-2, -1)) train_out_J_fft = torch.stack((train_out_J_fft.real, train_out_J_fft.imag), -1) train_labels_fft = torch.fft.fft2(train_labels, dim=(-2, -1)) train_labels_fft = torch.stack((train_labels_fft.real, train_labels_fft.imag), -1) train_fft_loss = 1 * nn.L1Loss()(train_out_J_fft, train_labels_fft)