pytorch显存管理、前向传播中间激活存储(intermediate activation)和torch.utils.checkpoint

参考

  • PyTorch显存机制分析
  • pytorch获得模型的参数量和模型的大小
  • TORCH.UTILS.CHECKPOINT
  • Training larger-than-memory PyTorch models using gradient checkpointing
  • Analysis of checkpoint mechanism of pytorch

前向传播的中间激活

最近希望能够在模型训练过程中改变中间的激活值,使改变后的中间激活值用于随后的反向传播中。
我们知道,反向传播需要使用前向传播的中间变量来计算梯度,而这些中间变量就存储在GPU的显存中,并且我现在还没找到如何从显存中将这些中间变量提取出来(知道的可以告诉我啊)。

  • 我们定义一个简单的模型如下:
import torch
from torch.utils.checkpoint import checkpoint

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.net1 = torch.nn.Linear(3, 300)
        self.net2 = torch.nn.Linear(300, 300)
        self.net3 = torch.nn.Linear(300, 400)
        self.net4 = torch.nn.Linear(400, 300)
        self.net5 = torch.nn.Linear(300, 100)
        self.activation_sum = 0
        self.activation_size = 0

    def forward(self, x):
        x = self.net1(x)
        self.activation_sum += x.nelement()
        self.activation_size += (x.nelement() * x.element_size())
        x = self.net2(x)
        self.activation_sum += x.nelement()
        self.activation_size += (x.nelement() * x.element_size())
        x = self.net3(x)
        self.activation_sum += x.nelement()
        self.activation_size += (x.nelement() * x.element_size())
        x = self.net4(x)
        self.activation_sum += x.nelement()
        self.activation_size += (x.nelement() * x.element_size())
        x = self.net5(x)
        self.activation_sum += x.nelement()
        self.activation_size += (x.nelement() * x.element_size())
        return x

可以看到,前向传播函数中,每一个x就是中间变量,最后一个x也就是结果,也算中间变量。所有的这些中间变量都会存储在显存中。我们在前向传播中,把中间结果的参数个数和总参数大小(字节为单位)存储为activation_sum 和 activation_size
我们接下来验证模型的中间变量存储在显存中。

  • 计算显存的使用:
    我们可以使用torch.cuda.memory_allocated()输出显存中存储的大小(字节为单位)。
  • 返回模型大小(字节为单位)的函数:
def modelSize(model):
    param_size = 0
    param_sum = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
        param_sum += param.nelement()
    buffer_size = 0
    buffer_sum = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        buffer_sum += buffer.nelement()
    all_size = (param_size + buffer_size)
    return all_size
  • 定义输入
device = torch.device("cuda:0")

input = torch.randn(10, 3).to(device)
label = torch.randn(10, 100).to(device)
  • 前向传播和反向传播
torch.cuda.empty_cache()
before = torch.cuda.memory_allocated()
model = MyModel().to("cuda:0")
after = torch.cuda.memory_allocated()
print("建立模型后显存变大{}".format(after - before))

print("模型大小为{}".format(modelSize(model)))

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model.train()
optimizer.zero_grad()

before = torch.cuda.memory_allocated()
print("模型前向传播前使用显存为{}".format(before))

output = model(input)  # 前向传播

after = torch.cuda.memory_allocated()
print("模型前向传播后使用显存为{},差值(中间激活)为{}".format(after, after - before))

loss = loss_fn(output, label)
torch.autograd.backward(loss)
optimizer.step()

结果为:

建立模型后显存变大1452544
模型大小为1449200
模型前向传播前使用显存为1457152
模型前向传播后使用显存为1514496,差值(中间激活)为57344

打印一下统计的中间结果(intermediate activation)

print(model.activation_sum)
print(model.activation_size)

结果为:

14000
56000

可以看到,显存中的模型大小和模型实际大小基本一样,就是模型参数的大小。模型前向传播前和前向传播后的显存变大,而这个值就和模型的中间结果大小相同。
这也证明了模型的中间结果存储在了显存中,反向传播计算完后即释放

使用checkpoint

checkpoint就是以时间换存储,使用了checkpoint包起来的层(也可以是连续的层),前向传播时就不需要存储中间结果,而是在反向传播时,需要中间变量时重新计算。
例如我们将模型重写如下:

import torch
from torch.utils.checkpoint import checkpoint

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.net1 = torch.nn.Linear(3, 300)
        self.net2 = torch.nn.Linear(300, 300)
        self.net3 = torch.nn.Linear(300, 400)
        self.net4 = torch.nn.Linear(400, 300)
        self.net5 = torch.nn.Linear(300, 100)
        self.activation_sum = 0
        self.activation_size = 0

    def forward(self, x):
        x = self.net1(x)
        self.activation_sum += x.nelement()
        self.activation_size += (x.nelement() * x.element_size())

        x = checkpoint(torch.nn.Sequential(self.net2, self.net3, self.net4), x)
        self.activation_sum += x.nelement()
        self.activation_size += (x.nelement() * x.element_size())
        x = self.net5(x)
        self.activation_sum += x.nelement()
        self.activation_size += (x.nelement() * x.element_size())
        return x

checkpoint举例
我们用checkpoint把self.net2, self.net3, self.net4包起来,这样的话,在前向传播是会存储self.net1(x)这一个中间结果,也就是self.net2的输入,然后self.net2, self.net3的结果都不会被存储,self.net4的输出也就是self.net5的输入会被存储。

这次的结果是:

建立模型后显存变大1452544
模型大小为1449200
模型前向传播前使用显存为1457152
模型前向传播后使用显存为1485824,差值(中间激活)为28672

结果
我们统计的模型的activation_sum 和 activation_size分别为:7000 28000,与通过显存计算出来的激活量28672基本一致。因为显存肯定要存储一些其他的值,所以肯定不是完全相同。

你可能感兴趣的:(pytorch,deep,learning,python,pytorch,深度学习)