> * 调用 numpy 作为其实现的一部分
> * 调用 SciPy 作为其实现的一部分
import torch
from torch.autograd import Function
该层并不做任何特别有用或数学上正确的计算,因此它被恰当地命名为 BadFFTFunction。
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
    """Parameter-free autograd layer implemented with NumPy's real 2-D FFT.

    The forward pass returns the magnitude of ``rfft2`` of the input and the
    backward pass applies ``irfft2`` to the incoming gradient. As the class
    name signals, this only demonstrates how to route NumPy code through
    autograd; it is not meant to be a mathematically exact derivative.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``abs(rfft2(input))`` as a tensor with ``input``'s dtype."""
        # detach so the tensor can be handed over to NumPy
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        # cast back to the input dtype regardless of the precision NumPy used
        return torch.from_numpy(result).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``irfft2(grad_output)`` as a tensor with ``grad_output``'s dtype."""
        # detach so the conversion to NumPy cannot fail on a tracked tensor
        numpy_go = grad_output.detach().numpy()
        result = irfft2(numpy_go)
        return torch.from_numpy(result).to(grad_output.dtype)
# This layer carries no parameters, so a plain function is enough;
# there is no need to wrap it in an nn.Module subclass.
def incorrect_fft(input):
    """Apply ``BadFFTFunction`` to ``input`` and return the result."""
    fft_layer = BadFFTFunction.apply
    return fft_layer(input)
# Demo: run the FFT layer on a random 8x8 input that tracks gradients.
input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
tensor([[ 3.4193, 6.2492, 1.4814, 3.9813, 14.7739],
[ 5.8865, 2.3533, 1.4765, 4.1547, 5.1887],
[ 6.9061, 4.7597, 5.9473, 11.8609, 10.0465],
[10.6724, 6.5964, 11.6951, 3.7400, 9.3114],
[ 4.9715, 5.7563, 10.5443, 4.6442, 2.3052],
[10.6724, 12.7221, 1.7242, 7.2647, 9.3114],
[ 6.9061, 8.1754, 4.0003, 3.1523, 10.0465],
[ 5.8865, 15.8094, 6.9266, 3.8533, 5.1887]], grad_fn=<BadFFTFunctionBackward>)
tensor([[ 0.4963, 0.5220, -0.0559, 0.2241, 2.1238, 0.0324, -0.1219, 0.7318],
[ 1.0476, -0.1841, -0.2717, 0.5470, -1.3174, -0.6817, 1.0102, -0.2014],
[-1.1131, 1.0054, 0.3593, 0.6158, 0.5398, 2.3020, -0.1305, -0.8611],
[-0.4693, -1.3720, -0.8196, -1.4975, -0.4474, -0.1150, -0.3285, 1.8079],
[-0.2928, 1.4333, -0.2744, -0.9194, -0.2592, 0.4996, 0.7862, 0.3972],
[ 0.7595, 0.5625, -0.7585, -0.1439, -0.5243, -1.0789, 0.4915, 1.5880],
[-0.6971, 0.0267, 0.2316, -0.8939, -1.9865, -0.7424, -0.6252, 0.8415],
[-0.9989, 1.0916, 0.2223, 2.1130, -0.3831, 0.9612, -1.8703, 0.4848]], requires_grad=True)
from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
class ScipyConv2dFunction(Function):
    """Autograd function for a 2-D cross-correlation computed with SciPy.

    Forward computes ``correlate2d(input, filter, mode='valid') + bias``.
    Backward returns the gradients with respect to ``input``, ``filter`` and
    ``bias``, each cast to the dtype of the tensor it belongs to.
    """

    @staticmethod
    def forward(ctx, input, filter, bias):
        """Return the 'valid' cross-correlation of ``input`` with ``filter`` plus ``bias``."""
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        # backward needs all three tensors (values and dtypes)
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_input, grad_filter, grad_bias)`` for ``grad_output``."""
        input, filter, bias = ctx.saved_tensors
        grad_output = grad_output.detach().numpy()
        # the bias is added to every output element, so its gradient is the
        # total of grad_output (keepdims keeps NumPy's all-axes reduction 2-D)
        grad_bias = np.sum(grad_output, keepdims=True)
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # the previous line can be expressed equivalently as:
        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
        # match each gradient's dtype to its own tensor instead of assuming float32
        return (
            torch.from_numpy(grad_input).to(input.dtype),
            torch.from_numpy(grad_filter).to(filter.dtype),
            torch.from_numpy(grad_bias).to(bias.dtype),
        )
class ScipyConv2d(Module):
    """Module with a learnable 2-D filter and bias, applied via ``ScipyConv2dFunction``."""

    def __init__(self, filter_width, filter_height):
        """Create a randomly initialised filter and a 1x1 bias.

        The filter tensor has shape ``(filter_width, filter_height)``.
        """
        super().__init__()
        kernel_shape = (filter_width, filter_height)
        self.filter = Parameter(torch.randn(*kernel_shape))
        self.bias = Parameter(torch.randn(1, 1))

    def forward(self, input):
        """Run ``ScipyConv2dFunction`` on ``input`` with this module's filter and bias."""
        weight, offset = self.filter, self.bias
        return ScipyConv2dFunction.apply(input, weight, offset)
# Demo: build the layer with a 3x3 filter and show its learnable parameters.
module = ScipyConv2d(3, 3)
print("Filter and bias: ", list(module.parameters()))
# Forward pass on a random 10x10 input that tracks gradients.
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
# Backpropagate a random upstream gradient; a 'valid' correlation of a
# 10x10 input with a 3x3 filter produces an 8x8 output, hence the 8x8 shape.
output.backward(torch.randn(8, 8))
print("Gradient for the input map: ", input.grad)
Filter and bias: [Parameter containing:
tensor([[-0.4935, -1.9784, -0.7520],
[ 0.0575, -1.3029, -1.9318],
[ 1.1692, 0.3187, -0.3044]], requires_grad=True), Parameter containing:
tensor([[-0.9122]], requires_grad=True)]
Output from the convolution: tensor([[ 1.9228e+00, 5.1708e+00, -1.7722e-01, 6.0846e-01, 3.1403e+00,
2.8270e+00, -2.7449e+00, -7.5212e+00],
[ 9.2997e-01, 1.4772e-01, -1.7859e+00, -4.8096e+00, -3.2894e+00,
2.9466e+00, 2.2201e+00, -1.7857e+00],
[-3.4693e+00, 1.0293e-01, -1.6728e+00, -6.0165e+00, -6.5472e+00,
-8.0421e-01, -2.2059e+00, -7.3037e+00],
[-7.1793e+00, 2.5055e-01, 1.0368e+00, -1.6319e+00, -8.3195e+00,
-4.6471e+00, -1.6651e+00, -3.5958e+00],
[-1.3189e+00, 1.9218e+00, 4.4448e+00, 2.2613e+00, -4.9712e+00,
-7.1531e+00, -7.3305e-01, -1.7002e+00],
[ 8.7136e-01, 6.6128e-01, 2.9627e+00, 1.1253e+00, -4.1352e+00,
-2.5405e+00, 2.5132e+00, 5.8881e+00],
[ 1.4052e+00, -2.3034e+00, -8.8434e-01, -5.1262e-02, -1.0159e+01,
-5.9782e+00, -2.7782e-01, 1.6121e+00],
[ 2.0402e+00, 1.7031e-03, 1.5932e+00, 4.6277e+00, -8.4549e+00,
-1.4017e+01, -6.9038e+00, -2.8451e+00]], grad_fn=<ScipyConv2dFunctionBackward>)
Gradient for the input map: tensor([[ 3.5673e-01, 1.5783e+00, 5.5266e-01, -2.2223e+00, -1.1648e+00,
1.3576e-01, -9.2908e-01, -9.5302e-01, -1.0746e-01, 4.7559e-02],
[-8.7051e-01, -2.4345e+00, 2.0736e-01, -1.8105e+00, -2.6859e+00,
-1.2512e+00, -3.8021e-01, -6.9782e-01, 2.8626e+00, 1.4599e+00],
[-1.1194e+00, -4.4828e+00, -2.9077e+00, 1.2989e+00, 7.7022e-01,
1.9482e+00, 8.1654e-01, 4.3522e+00, 4.5537e+00, 3.8784e+00],
[ 2.0067e+00, 3.9783e-01, 5.8589e-01, -1.2962e+00, -1.7774e+00,
2.7983e-01, 2.0063e+00, -5.3575e-01, 1.9437e+00, 9.2401e-01],
[ 9.6117e-01, 1.1666e+00, 2.1162e+00, 2.9779e+00, -4.0048e+00,
-2.1271e-01, 3.6012e-01, -1.5548e-03, -2.2730e+00, -1.9148e+00],
[-5.5450e-01, -3.5749e+00, 2.1108e-01, 2.7341e+00, -5.4415e-01,
-1.0326e-01, 5.2842e+00, 6.0280e+00, 1.7129e+00, -1.0743e+00],
[-4.5225e-01, -3.2305e+00, -3.9962e+00, -1.4418e+00, -6.6439e+00,
-6.6198e+00, -4.7138e+00, 2.4043e-01, 3.5485e+00, -2.6767e-01],
[ 1.2502e+00, 1.5484e-01, 4.7037e+00, 4.7502e+00, -2.7102e-01,
-5.2364e+00, -7.6462e+00, -1.5124e+00, 3.1666e+00, 1.0458e+00],
[ 7.6941e-01, 1.0242e-01, 3.4887e+00, 6.9709e+00, 4.7607e+00,
1.1723e+00, -2.5379e+00, 1.0167e+00, 3.5257e+00, 2.5870e+00],
[ 1.8773e-01, -2.9968e+00, -2.2132e+00, -2.2993e-01, 2.0811e+00,
5.9679e-01, -1.6988e+00, -1.8538e+00, -1.1920e-01, 4.0583e-01]])
from torch.autograd.gradcheck import gradcheck
# Check the hand-written backward pass with torch's gradcheck utility.
moduleConv = ScipyConv2d(3, 3)
# A double-precision input is used for the check; it is wrapped in a list
# because gradcheck takes the collection of inputs to call the function with.
input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)
Are the gradients correct: True