##概述
最近几天有个模型,使用了local adaptation,就是在inference阶段也加入了训练的过程,因为需要进行模型复制和较大的batch_size进行训练,这样就导致我显卡直接爆炸。我用的看的K80,显存12G。模型用的BERT+文本分类,输入的Sequence_size=512(下简写为seq_size)。
##思考过程
接下来分享一下我的思考过程,可以看到就模型本身是不大的,即使扩大一倍也不会爆显存,问题就出在这个输入上,seq_size较大的时候,batch_size就要注意了,不然会导致计算的时候矩阵过大,有可能爆显存。我一直把inference的input_batch_size调小到1,都还是不行。说明目前来看不是输入位置的和前向计算过程中导致的爆显存。
接下来调试代码到local adaptation部分,这里会从memory里匹配出32个数据来进行模型参数的更新,我在训练阶段就没有设置过这么大的batch_size,因为一定会爆,问题应该就出在这里了。
##解决过程
首先想到的就是直接开并行,在model的最外层直接用DataParallel封装。使用4个GPU并行处理,发现不行,在local adaptation的地方出现数据分发的错误,源代码很长,就改成一个类似伪代码的片段给大家展示一下
import torch
class Net(torch.nn.Module):
def __init__(model_path):
super(Net, self).__init__()
self.classify = load_model(model_path)
def forward(input):
self.local_classify = copy.deepcopy(self.classify)
logits = self.train(self.local_classify)
input = input.cuda()
return logits
def train(model,data):
#TODO: train and update model
return logits
def main():
model = Net(model_path)
model = torch.nn.DataParallel(model).cuda()
input = load_data(data_path)
output = model(input)
if __name__=='__main__':
main()
执行代码使用如下:
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py
于是我改回用单卡跑,将训练地方自己手动拆分batch为更小的batch分多次计算后再合并loss计算(这个部分我用完一个矩阵就会立马释放),仍然爆显存。看来不得不使用多GPU并行了。
先说一下DataParallel的机制,属于数据并行,首先这是一个负载不均衡的并行,会首先将模型(M)复制成相同的N份到每张GPU上,得到{[M1,M2,...,Mi,...,MN],Mi=M,i=[1,N]},然后再将batch_size等分为N份,分别送到N张GPU上计算,计算完毕后汇总loss,在主卡上(默认为device:0)进行梯度计算并更新模型M1,更新完M1后,再将其重新复制剩下的N-1张卡上,这样就逐次迭代就完成了并行。
那么在pytorch里进行GPU并行是需要指定GPU的编号的,我们用torch.device('cuda')可将模型传到GPU上,默认情况下,不指定编号,就是会放在device 0上,在本代码中出现了两个模型,一个需要训练(称为train_model),一个不需要训练(称为static_model),那么我们最好将其放在不同的GPU上,防止训练阶段负载不均衡的时候影响到static_model。于是来该代码,详细设置需要GPU计算的地方的参数,这里仍然是类似伪代码,简短点争取让大家看清楚核心部分。
import torch
class Net(torch.nn.Module):
def __init__(model_path):
super(Net, self).__init__()
self.classify = load_model(model_path)
def forward(input):
self.local_classify = copy.deepcopy(self.classify)
#cuda后面跟编号的时候就设置哪种卡为当前主卡
device = torch.device('cuda:1')
#这里注意看device_ids这个参数,是指定在哪几张卡上并行的
torch.nn.DataParallel(self.local_classify,device_ids=[1,2,3,4]).to(device)
data = self.get_memory()
data = data.to(device)
logits = self.train(self.local_classify,data)
return logits
def train(model,data):
#TODO: train and update model
return logits
def get_memory(batch_size=32):
#TODO: get samples from memory
return memory
def main():
model = Net(model_path)
model = torch.nn.cuda()#这里会默认使用device:0
input = load_data(data_path)
input = input.cuda()
output = model(input)
if __name__=='__main__':
main()
##结果
最后呢,改成上面这种模式,开了5张卡,卡0放static_model,卡1-4放train_model进行训练,就很好的解决了计算部分导致的显存爆炸。
这种多个模型同时在不同卡之间跑,如果有数据上的共享和同步,记得一定要注意数据当时存放在哪。
整理完这个bug,发现在并行前,最好还是使用torch.device和device_ids来指定GPU使用,方便管理。
代码等我调试完所有的部分后传到github上进行分享。再次感谢大家