用于CUDA FFT的PyTorch包装器pytorch-fft

用于CUDA FFT的PyTorch包装器pytorch-fft

原文:https://ptorch.com/news/73.html

由Eric Wong提供的PyTorch C扩展程序包,用于执行批量的2D CuFFT转换

安装

这个包在PyPi上。使用pip install pytorch-fft即可安装

用法

pytorch_fft.fft模块中,您可以使用以下函数执行前向和后向FFT转换(复杂到复杂)

  • fftifft一维变换
  • fft2ifft2 2D转换
  • fft3ifft3 3D转换

从同一个模块中,还可以使用以下方法实现复杂/复杂到实际的FFT转换

  • rfftirfft一维变换
  • rfft2irfft2 2D转换
  • rfft3irfft3 3D转换

对于d-D变换,需要输入张量具有> =(d + 1)尺寸(N1 X ... X NK X M1 X ... X MD),其中n1 x ... x nk是批处理FFT变换,并且m1 x ... x md是尺寸的 d-D变换。d必须是从1到3的数字。

最后,该模块包含以下帮助函数,您可能会觉得有用

  • reverse(X, group_size=1)颠倒张量的元素,并返回一个新的张量的结果。请注意,PyTorch目前不支持负面切片,请参阅此 问题。如果提供了一个组的大小,这些元素将在这个大小的组中反转。
  • expand(X, imag=False, odd=True)采用实际的二维或三维FFT的张量输出,并用其冗余条目进行扩展,以匹配复数FFT的输出。

对于autograd支持,请在pytorch_fft.fft.autograd模块中使用以下功能 :

  • FftIfft一维变换
  • Fft2dIfft2d 2D转换
  • Fft3dIfft3d 3D转换
# Example that does a batch of three 2D transformations of size 4 by 5. 
import torch
import pytorch_fft.fft as fft

A_real, A_imag = torch.randn(3,4,5).cuda(), torch.zeros(3,4,5).cuda()
B_real, B_imag = fft.fft2(A_real, A_imag)
fft.ifft2(B_real, B_imag) # equals (A, zeros)

B_real, B_imag = fft.rfft2(A) # is a truncated version which omits
                                   # redundant entries

reverse(torch.arange(0,6)) # outputs [5,4,3,2,1,0]
reverse(torch.arange(0,6), 2) # outputs [4,5,2,3,0,1]

expand(B_real) # is equivalent to  fft.fft2(A, zeros)[0]
expand(B_imag, imag=True) # is equivalent to  fft.fft2(A, zeros)[1]
# Example that uses the autograd for 2D fft:
import torch
from torch.autograd import Variable
import pytorch_fft.fft.autograd as fft
import numpy as np

f = fft.Fft2d()
invf= fft.Ifft2d()

fx, fy = (Variable(torch.arange(0,100).view((1,1,10,10)).cuda(), requires_grad=True), 
          Variable(torch.zeros(1, 1, 10, 10).cuda(),requires_grad=True))
k1,k2 = f(fx,fy)
z = k1.sum() + k2.sum()
z.backward()
print(fx.grad, fy.grad)

笔记

这个遵循NumPy的语义和行为,就比如ifft2(fft2(x)) = x。请注意,用于反向FFTCuFFT语义只会翻转变换的符号,但它不是真正的逆。

同样,真正的complex / complex到真正的变体也遵循NumPy的语义和行为。在1D情况下,这意味着对于大小的输入N,它返回大小的输出N//2+1(它省略了多余的条目,请参阅Numpy文档)

pytorch_fft.fft模块中的函数不实现PyTorch autograd Function,并且在语义和功能上都与它们的numpy等价。

Autograd功能在pytorch_fft.fft.autograd模块中。

知识库内容

  • pytorch_fft/src:C源代码
  • pytorch_fft/fft:Python便捷包装
  • build.py:编译文件
  • test.py:测试NumPy FFTAutograd检查

问题和贡献

如果您有任何问题或功能要求,请提交问题或发送PR。

你可能感兴趣的:(用于CUDA FFT的PyTorch包装器pytorch-fft)