torch.fft.fft2.() 报错问题解决

运行别人的的源代码还报错,所以确定不是代码问题。

torch.fft.fft2.() 报错问题解决_第1张图片

问了别人,应该是fft函数对应的torch版本问题,torch1.8.0版本之后的才是

torch.fft.fft2

根据网上的总结自己改的

#旧版          新版
torch.rfft   torch.fft.fft2
torch.irfft  torch.fft.ifft2

还是报错了,函数中使用的参数定义应该也是不一样的

torch.fft.fft2.() 报错问题解决_第2张图片

原来不是版本问题,看到pytorch官网上只有这两种函数,所以猜想是不是我写错了

torch.fft.fft2.() 报错问题解决_第3张图片

torch.fft.fft2.() 报错问题解决_第4张图片

   解决办法,把源代码的torch.fft.fft2改成torch.fft.fftn就可以了

def D(x, Dh_DFT, Dv_DFT):
    x_DFT = torch.fft.fftn(x, dim=(-2,-1)).cuda()
    Dh_x = torch.fft.ifftn(Dh_DFT*x_DFT, dim=(-2,-1)).real
    Dv_x = torch.fft.ifftn(Dv_DFT*x_DFT, dim=(-2,-1)).real
    return Dh_x, Dv_x

然后代码正常运行啦,为了证明代码真的运行成功过,因为最近不知道怎么回事,又开始报和原来一样的错误了,找不到原因,但是至少成功过,那我的改的应该是对的。

解决方法来了:把torch改成np 就可以啦
torch.fft.fft2.() 报错问题解决_第5张图片

新错误哈哈哈,继续搞

搜索发现numpy形式的傅里叶变换中间参数需要对应有变化

fft.fft2(a, s=None, axes=(- 2, - 1), norm=None)

 即torch中的参数dim在np中对应axes

torch.fft.fft2.() 报错问题解决_第6张图片

 修改之后果然没有这个错误了,cuda应该是tensor形式中用到的。

观察代码中 D 的输入参数发现都是tensor形式的,所以决定把fft代码改为torch的继续调整。

方法1:需显式的引入torch.fft才行

import torch
改为
import torch.fft

报错改变,如下

torch.fft.fft2.() 报错问题解决_第7张图片

 于是根据以上思路,改成更详细的显式引入:torch.fft.fft2

报错了,没有这个模块........

 以上方法不行,就会去解决cuda()的问题,更改代码如下,由于调用cuda需要tensor数据,所以把定义的变量转化为tensor,虽然有点麻烦,但是挺有用的,程序里已经没有fft的报错了,至少代码往下进行了。

torch.fft.fft2.() 报错问题解决_第8张图片

你可能感兴趣的:(python,pytorch,深度学习)