学习记录2:pytorch中FFT

pytorch中实现根据版本不同,使用的函数也不同,最直接的区别就是旧版fft后出现的是实数,而新版出来的是复数,这里进行一个记录:参考这个

1.pytorch旧版本(1.7之前)中有一个函数torch.rfft()

学习记录2:pytorch中FFT_第1张图片

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x_res = torch.randn(1, 3, 256, 256).to(device)
output_fft_old = torch.rfft(x_res, signal_ndim=2, normalized=False, onesided=False)
print('output_fft_old:',output_fft_old.size())
x_real = output_fft_old[...,0]
print('x_real:',x_real.size())
x_imaginary = output_fft_old[...,1]
print('x_imaginary:',x_imaginary.size())
output_ifft_old = torch.irfft(output_fft_old , signal_ndim=2, normalized=False, onesided=False)
print('output_ifft_old:',output_ifft_old.size())

--------------------------输出结果--------------------------

output_fft_old: torch.Size([1, 3, 256, 256, 2])
x_real: torch.Size([1, 3, 256, 256])
x_imaginary: torch.Size([1, 3, 256, 256])
output_ifft_old: torch.Size([1, 3, 256, 256])

--------------------------输出结果--------------------------

2.新版本(1.8、1.9)中被移除了,添加了torch.fft.rfft()

新版的rfft和fft都是用于一维输入,而我们的图像是二维,所以应该用rfft2和fft2。在fft2中,参数dim用来指定用于傅里叶变换的维度,默认(-2,-1),正好对应H、W两个维度。

学习记录2:pytorch中FFT_第2张图片

我打印了结果,可以看出来此时fft结束后输出的是复数

学习记录2:pytorch中FFT_第3张图片

打印实部,这个.real 在旧版里面测试是没有的

学习记录2:pytorch中FFT_第4张图片

 打印虚部

进行反变换

output_ifft_new = torch.fft.ifft2(torch.complex(output_new_2dim[..., 0], output_new_2dim[..., 1]), dim=(-2, -1))    # 如果运行了torch.stack()
output_ifft_new = torch.fft.ifft2(output_new_2dim , dim=(-2, -1)) 

或者直接获取实部和虚部

学习记录2:pytorch中FFT_第5张图片

 利用幅值和相位进行反变换

学习记录2:pytorch中FFT_第6张图片

 x_fft_res = x_fft_res.real  这个就是反变换后的结果

则在新版里面的frequency loss可以是:

train_out_J_fft = torch.fft.fft2(train_out_J, dim=(-2, -1))
train_out_J_fft = torch.stack((train_out_J_fft.real, train_out_J_fft.imag), -1)
train_labels_fft = torch.fft.fft2(train_labels, dim=(-2, -1))
train_labels_fft = torch.stack((train_labels_fft.real, train_labels_fft.imag), -1)

train_fft_loss = 1 * nn.L1Loss()(train_out_J_fft, train_labels_fft)

 

 

 

 

你可能感兴趣的:(torch,FFT,pytorch,学习,深度学习)