原文:https://ptorch.com/news/73.html
由Eric Wong提供的PyTorch C
扩展程序包,用于执行批量的2D CuFFT
转换
这个包在PyPi
上。使用pip install pytorch-fft
即可安装
从pytorch_fft.fft
模块中,您可以使用以下函数执行前向和后向FFT
转换(复杂到复杂)
fft
和ifft
一维变换fft2
和ifft2
2D
转换fft3
和ifft3
3D
转换从同一个模块中,还可以使用以下方法实现复杂/复杂到实际的FFT转换
rfft
和irfft
一维变换rfft2
和irfft2
2D
转换rfft3
和irfft3
3D
转换对于d-D
变换,需要输入张量具有> =(d + 1)
尺寸(N1 X ... X NK X M1 X ... X MD)
,其中n1 x ... x nk
是批处理FFT
变换,并且m1 x ... x md
是尺寸的 d-D
变换。d
必须是从1到3的数字。
最后,该模块包含以下帮助函数,您可能会觉得有用
reverse(X, group_size=1)
颠倒张量的元素,并返回一个新的张量的结果。请注意,PyTorch目前不支持负面切片,请参阅此 问题。如果提供了一个组的大小,这些元素将在这个大小的组中反转。expand(X, imag=False, odd=True)
采用实际的二维或三维FFT的张量输出,并用其冗余条目进行扩展,以匹配复数FFT的输出。对于autograd
支持,请在pytorch_fft.fft.autograd
模块中使用以下功能 :
Fft
和Ifft
一维变换Fft2d
和Ifft2d
2D
转换Fft3d
和Ifft3d
3D
转换# Example that does a batch of three 2D transformations of size 4 by 5.
import torch
import pytorch_fft.fft as fft
A_real, A_imag = torch.randn(3,4,5).cuda(), torch.zeros(3,4,5).cuda()
B_real, B_imag = fft.fft2(A_real, A_imag)
fft.ifft2(B_real, B_imag) # equals (A, zeros)
B_real, B_imag = fft.rfft2(A) # is a truncated version which omits
# redundant entries
reverse(torch.arange(0,6)) # outputs [5,4,3,2,1,0]
reverse(torch.arange(0,6), 2) # outputs [4,5,2,3,0,1]
expand(B_real) # is equivalent to fft.fft2(A, zeros)[0]
expand(B_imag, imag=True) # is equivalent to fft.fft2(A, zeros)[1]
# Example that uses the autograd for 2D fft:
import torch
from torch.autograd import Variable
import pytorch_fft.fft.autograd as fft
import numpy as np
f = fft.Fft2d()
invf= fft.Ifft2d()
fx, fy = (Variable(torch.arange(0,100).view((1,1,10,10)).cuda(), requires_grad=True),
Variable(torch.zeros(1, 1, 10, 10).cuda(),requires_grad=True))
k1,k2 = f(fx,fy)
z = k1.sum() + k2.sum()
z.backward()
print(fx.grad, fy.grad)
这个遵循NumPy
的语义和行为,就比如ifft2(fft2(x)) = x
。请注意,用于反向FFT
的CuFFT
语义只会翻转变换的符号,但它不是真正的逆。
同样,真正的complex / complex
到真正的变体也遵循NumPy
的语义和行为。在1D
情况下,这意味着对于大小的输入N
,它返回大小的输出N//2+1
(它省略了多余的条目,请参阅Numpy
文档)
pytorch_fft.fft
模块中的函数不实现PyTorch autograd Function
,并且在语义和功能上都与它们的numpy
等价。
Autograd
功能在pytorch_fft.fft.autograd
模块中。
pytorch_fft/src
:C源代码pytorch_fft/fft
:Python便捷包装build.py
:编译文件test.py
:测试NumPy FFT
和Autograd
检查如果您有任何问题或功能要求,请提交问题或发送PR。