PyTorch 拓展——SciPy

代码

# coding:utf-8
from numpy import flip
import torch
from torch.autograd import Function, Variable
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    """Custom autograd function: 2-D cross-correlation plus bias via SciPy.

    The forward pass runs ``scipy.signal.correlate2d`` in ``'valid'`` mode on
    NumPy views of the tensors, so autograd cannot trace it; ``backward``
    supplies the gradients by hand. Because ``Tensor.numpy()`` is used, the
    tensors are expected to live on the CPU.
    """

    @staticmethod
    def forward(ctx, input, filter, bias):
        """Compute ``correlate2d(input, filter, mode='valid') + bias``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input: 2-D tensor to be filtered.
            filter: 2-D kernel tensor.
            bias: tensor that is broadcast-added to the correlation result.

        Returns:
            A tensor holding the result, cast to ``input.dtype``.
        """
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. ``input``, ``filter`` and ``bias``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient of the loss w.r.t. the forward output.

        Returns:
            A 3-tuple ``(grad_input, grad_filter, grad_bias)``; each gradient
            is cast to the dtype of the tensor it belongs to.
        """
        input, filter, bias = ctx.saved_tensors
        grad_out = grad_output.detach().numpy()
        # The bias is added to every output element, so its gradient is the
        # sum of the incoming gradient (kept 2-D via keepdims).
        grad_bias = np.sum(grad_out, keepdims=True)
        # Full convolution with the kernel; equivalent to a full correlation
        # with the kernel flipped along both axes.
        grad_input = convolve2d(grad_out, filter.numpy(), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_out, mode='valid')
        # Cast each gradient to the dtype of its own tensor instead of a
        # hard-coded float32, so float64 gradients are not rounded through
        # single precision.
        return (
            torch.as_tensor(grad_input, dtype=input.dtype),
            torch.as_tensor(grad_filter, dtype=filter.dtype),
            torch.as_tensor(grad_bias, dtype=bias.dtype),
        )


class ScipyConv2d(Module):
    """Module wrapper around ``ScipyConv2dFunction`` with a learnable kernel and bias."""

    def __init__(self, filter_width, filter_height):
        """Create a randomly initialised kernel and a ``(1, 1)`` bias.

        Args:
            filter_width: size of the kernel's first dimension.
            filter_height: size of the kernel's second dimension.
        """
        super().__init__()
        kernel_shape = (filter_width, filter_height)
        # Wrapping in Parameter registers the tensor with the Module, so it
        # shows up in module.parameters(); a plain tensor created with
        # requires_grad=True would still receive gradients but would not be
        # registered.
        self.filter = Parameter(torch.randn(kernel_shape))
        self.bias = Parameter(torch.randn((1, 1)))

    def forward(self, input):
        """Apply the SciPy-backed correlation to ``input`` and return the result."""
        conv = ScipyConv2dFunction.apply
        return conv(input, self.filter, self.bias)


# Demo: build a 3x3 SciPy-backed convolution module and inspect its parameters.
module = ScipyConv2d(3, 3)
params = list(module.parameters())
print("Filter and bias: ", params)

# Forward pass on a random 10x10 map that itself tracks gradients.
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)

# No backward pass has run yet, so the filter's .grad is printed before it.
print('求梯度之前filter的梯度', module.filter.grad)

# The gradient passed to backward() must have the same shape as output (8x8).
upstream_grad = torch.randn(8, 8)
output.backward(upstream_grad)
print("Gradient for the input map: ", input.grad)

print('filter:', module.filter.grad)

解释:

  1. 前面的 torch.autograd.Function 负责具体的前向传播和反向传播。
  2. Module 只需要定义 forward 前向传播方法,反向传播由 autograd 根据 Function 中的 backward 自动完成。
  3. ScipyConv2dFunction.apply(input, self.filter, self.bias) 这里使用了 apply 方法。
    ScipyConv2dFunction()(input, self.filter, self.bias) 是会报错的,因为这里的 apply 是 Python 的 classmethod,需要通过类直接调用:
    @classmethod
    def apply(cls, *args, **kwargs): # real signature unknown
        pass

总结:采用 apply 调用时,forward 和 backward 必须加上 @staticmethod;只有旧式(已弃用)的、通过实例 Function()(...) 调用的写法才不需要加 staticmethod。

你可能感兴趣的:(pytorch)