实验 pytorch 版本1.0.1
pytorch 的 checkpoint 是一个可以用时间换空间的技术,很多情况下可以轻松实现 batch_size 翻倍的效果
坑
checkpoint 的输入需要requires_grad为True,不然在反向传播时不会计算内部梯度
简单让输入的requires_grad为True并且节省显存的办法
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class Net(nn.Module):
def __init__():
# 初始化。。。
def forward(self, x):
# 注意这行
x = x + torch.zeros(1, dtype=x.dtype, device=x.device, requires_grad=True)
y = checkpoint(seg_func, x)
# 继续其他操作
return y
验证实验
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
conv = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=1)
def seg1(x):
return conv(x)
print('查看conv里面的梯度,一开始应当全为0或None')
print(conv.weight.grad)
print(conv.bias.grad)
x = torch.ones(1, 1, 1, 1)
y = seg1(x).mean() - 3
y.backward()
print('查看conv里面的梯度,现在应该不为0了')
print(conv.weight.grad)
print(conv.bias.grad)
print('清空conv的梯度,进行下一次测试')
conv.weight.grad.data.zero_()
conv.bias.grad.data.zero_()
print('查看conv里面的梯度,现在应该全为0了')
print(conv.weight.grad)
print(conv.bias.grad)
y = checkpoint(seg1, x).mean() - 3
try:
print('此时应当会失败,y并不是计算图的一部分,因为x的requires_grad为False,checkpoint认为这段函数是不需要计算梯度的')
y.backward()
except RuntimeError as e:
print('backward果然抛出异常了')
print('查看conv里面的梯度,现在应该保持不变,仍然全为0了')
print(conv.weight.grad)
print(conv.bias.grad)
print('让输入的requires_grad为True,有俩个办法,一个是直接设定x的requires_grad为True,另外一个办法就是与另外一个requires_grad为True的常量合并操作')
print('这里使用的是合并操作,因为有时候并不能直接设置输入的requires_grad=True,另外我认为合并操作占用的显存更少,因为grad的shape跟原始变量是一样的'
',使用合并操作,额外无用的grad的size只有1,而设定输入的requires_grad为True,额外无用的grad的size跟输入一样大')
x2 = x + torch.zeros(1, dtype=x.dtype, device=x.device, requires_grad=True)
y = checkpoint(seg1, x2).mean() - 3
y.backward()
print('现在backward不会报错了')
print('查看conv里面的梯度,现在不为0了')
print(conv.weight.grad)
print(conv.bias.grad)
print('实验完成')