pytorch_fft

https://github.com/locuslab/pytorch_fft


A PyTorch wrapper for CUDA FFTs

A package that provides a PyTorch C extension for performing batches of 2D CuFFT transformations, by Eric Wong.

Installation

This package is on PyPI. Install with pip install pytorch-fft.

Usage

  • From the pytorch_fft.fft module, use fft2 and ifft2 to perform the forward and inverse FFT transformations.
  • The input tensors must have >= 3 dimensions (n1 x ... x nk x row x col), where n1 x ... x nk is the batch of FFT transformations and row x col are the dimensions of each transformation.

```python
import torch
import pytorch_fft.fft as fft

# A batch of 3 complex 4x5 matrices, stored as separate real and imaginary parts on the GPU
A_real, A_imag = torch.randn(3, 4, 5).cuda(), torch.zeros(3, 4, 5).cuda()
B_real, B_imag = fft.fft2(A_real, A_imag)  # forward 2D FFT over the last two dimensions
fft.ifft2(B_real, B_imag)  # equals (A_real, A_imag), up to floating-point error
```
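To make the batching convention concrete, here is a CPU-only NumPy sketch of the semantics described above: every leading dimension is treated as a batch index, and a 2D complex FFT is applied over the last two axes (row x col). The functions `reference_fft2` and `reference_ifft2` are hypothetical and not part of pytorch_fft. They only mirror the split real/imaginary calling convention of `fft.fft2`/`fft.ifft2`, assuming `numpy.fft.fft2`/`numpy.fft.ifft2` as the reference, in the spirit of the package's tests against NumPy FFTs.

```python
import numpy as np


def reference_fft2(x_real, x_imag):
    """Hypothetical NumPy reference for a batched 2D complex-to-complex FFT.

    Inputs have shape (n1, ..., nk, row, col); the FFT is taken over the
    last two axes and every leading axis is treated as a batch dimension.
    Returns the real and imaginary parts separately.
    """
    if x_real.ndim < 3:
        raise ValueError("expected >= 3 dimensions: (n1 x ... x nk x row x col)")
    x = x_real + 1j * x_imag
    y = np.fft.fft2(x, axes=(-2, -1))  # transform over (row, col) only
    return y.real, y.imag


def reference_ifft2(y_real, y_imag):
    """Hypothetical NumPy reference for the normalized inverse (NumPy semantics)."""
    y = y_real + 1j * y_imag
    x = np.fft.ifft2(y, axes=(-2, -1))
    return x.real, x.imag


rng = np.random.default_rng(0)
a_real, a_imag = rng.standard_normal((3, 4, 5)), np.zeros((3, 4, 5))
b_real, b_imag = reference_fft2(a_real, a_imag)

# The batched result matches transforming each 4x5 slice independently.
for i in range(a_real.shape[0]):
    assert np.allclose(b_real[i] + 1j * b_imag[i], np.fft.fft2(a_real[i]))

# The round trip recovers the input, as in ifft2(fft2(x)) = x.
c_real, c_imag = reference_ifft2(b_real, b_imag)
assert np.allclose(c_real, a_real) and np.allclose(c_imag, a_imag)
```

Functions like these could supply expected values when checking GPU results that have been moved to the CPU with `.cpu().numpy()`.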

Notes

  • These functions follow NumPy semantics, so ifft2(fft2(x)) = x. CuFFT's inverse FFT, by contrast, only flips the sign of the exponent and does not normalize, so it is not a true inverse: applying it after the forward transform returns x scaled by row x col.
  • These functions are NOT PyTorch autograd Functions, so gradients cannot be backpropagated through them. The package only lets you call CuFFT on PyTorch Tensors.
  • The code currently implements only batched 2D complex-to-complex transformations. If you need a different number of dimensions, the source code should be straightforward to extend.
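The normalization difference noted above can be checked numerically without a GPU. In this sketch, `unnormalized_ifft2` is a hypothetical stand-in for an inverse that, like CuFFT's, only flips the sign of the exponent. It is built from NumPy's forward FFT using the conjugation identity conj(fft2(conj(Y))), which equals the sum over k of Y_k e^{+2πi k·n/N} with no 1/N factor. This only illustrates the scaling; it does not call CuFFT or pytorch_fft.

```python
import numpy as np


def unnormalized_ifft2(y):
    """Hypothetical inverse 2D FFT that only flips the exponent's sign (no 1/N).

    conj(fft2(conj(y))) computes sum_k y_k * exp(+2*pi*i*k.n/N) over the
    last two axes, i.e. an inverse transform without normalization.
    """
    return np.conj(np.fft.fft2(np.conj(y), axes=(-2, -1)))


rng = np.random.default_rng(1)
x = rng.standard_normal((2, 4, 5)) + 1j * rng.standard_normal((2, 4, 5))
rows, cols = x.shape[-2:]

y = np.fft.fft2(x, axes=(-2, -1))
x_scaled = unnormalized_ifft2(y)

# The sign flip alone is not a true inverse: the result is scaled by row * col.
assert np.allclose(x_scaled, rows * cols * x)

# Dividing by row * col recovers NumPy semantics, ifft2(fft2(x)) = x.
assert np.allclose(x_scaled / (rows * cols), np.fft.ifft2(y, axes=(-2, -1)))
assert np.allclose(x_scaled / (rows * cols), x)
```

A wrapper that exposes NumPy semantics on top of an unnormalized inverse therefore has to divide by row x col after the inverse transform.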

Repository contents

  • pytorch_fft/src: C source code
  • pytorch_fft/fft: Python convenience wrapper
  • build.py: compilation file
  • test.py: tests against NumPy FFTs

Issues and Contributions

If you have any issues or feature requests, file an issue or send in a PR.

