pytorch checkpoint 函数的坑

实验 pytorch 版本1.0.1
pytorch 的 checkpoint 是一个可以用时间换空间的技术,很多情况下可以轻松实现 batch_size 翻倍的效果


checkpoint 的输入需要requires_grad为True,不然在反向传播时不会计算内部梯度
简单让输入的requires_grad为True并且节省显存的办法

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__():
        # 初始化。。。
        
    def forward(self, x):
        # 注意这行
        x = x + torch.zeros(1, dtype=x.dtype, device=x.device, requires_grad=True)
        y = checkpoint(seg_func, x)
        # 继续其他操作
        return y

验证实验

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

conv = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=1)

def seg1(x):
    return conv(x)

print('查看conv里面的梯度,一开始应当全为0或None')
print(conv.weight.grad)
print(conv.bias.grad)

x = torch.ones(1, 1, 1, 1)
y = seg1(x).mean() - 3
y.backward()

print('查看conv里面的梯度,现在应该不为0了')
print(conv.weight.grad)
print(conv.bias.grad)

print('清空conv的梯度,进行下一次测试')
conv.weight.grad.data.zero_()
conv.bias.grad.data.zero_()

print('查看conv里面的梯度,现在应该全为0了')
print(conv.weight.grad)
print(conv.bias.grad)

y = checkpoint(seg1, x).mean() - 3
try:
    print('此时应当会失败,y并不是计算图的一部分,因为x的requires_grad为False,checkpoint认为这段函数是不需要计算梯度的')
    y.backward()
except RuntimeError as e:
    print('backward果然抛出异常了')

print('查看conv里面的梯度,现在应该保持不变,仍然全为0了')
print(conv.weight.grad)
print(conv.bias.grad)

print('让输入的requires_grad为True,有俩个办法,一个是直接设定x的requires_grad为True,另外一个办法就是与另外一个requires_grad为True的常量合并操作')
print('这里使用的是合并操作,因为有时候并不能直接设置输入的requires_grad=True,另外我认为合并操作占用的显存更少,因为grad的shape跟原始变量是一样的'
      ',使用合并操作,额外无用的grad的size只有1,而设定输入的requires_grad为True,额外无用的grad的size跟输入一样大')
x2 = x + torch.zeros(1, dtype=x.dtype, device=x.device, requires_grad=True)
y = checkpoint(seg1, x2).mean() - 3
y.backward()
print('现在backward不会报错了')
print('查看conv里面的梯度,现在不为0了')
print(conv.weight.grad)
print(conv.bias.grad)
print('实验完成')

你可能感兴趣的:(checkpoint,pytorch,神经网络,python)