最近在看transformer,以及相关的分布式实现,发现有些问题不太明白,顺便记录下,如有错误欢迎大佬指正。
在attention的介绍中(http://nlp.seas.harvard.edu/2018/04/03/attention.html),对并行化的训练给出了代码,但没有做详细的介绍,这里仔细的分析下相关代码:
# Skip if not interested in multigpu.
class MultiGPULossCompute:
"A multi-gpu loss compute and train function."
def __init__(self, generator, criterion, devices, opt=None, chunk_size=5):
# Send out to different gpus.
self.generator = generator
self.criterion = nn.parallel.replicate(criterion,
devices=devices)
self.opt = opt
self.devices = devices
self.chunk_size = chunk_size
def __call__(self, out, targets, normalize):
total = 0.0
# 复制linear激活函数
generator = nn.parallel.replicate(self.generator,
devices=self.devices)
# 按batch切分decoder输出
out_scatter = nn.parallel.scatter(out,
target_gpus=self.devices)
out_grad = [[] for _ in out_scatter]
# 切分label
targets = nn.parallel.scatter(targets,
target_gpus=self.devices)
# Divide generating into chunks.
chunk_size = self.chunk_size
# 切分句子
for i in range(0, out_scatter[0].size(1), chunk_size):
# Predict distributions
# 将decoder输出按句子切分
out_column = [[Variable(o[:, i:i + chunk_size].data,
requires_grad=self.opt is not None)]
for o in out_scatter]
# 执行linear函数
gen = nn.parallel.parallel_apply(generator, out_column)
# Compute loss.
y = [(g.contiguous().view(-1, g.size(-1)),
t[:, i:i + chunk_size].contiguous().view(-1))
for g, t in zip(gen, targets)]
loss = nn.parallel.parallel_apply(self.criterion, y)
# Sum and normalize loss
l = nn.parallel.gather(loss,
target_device=self.devices[0])
l = l.sum() / normalize
# print(l)
total += l
# Backprop loss to output of transformer
if self.opt is not None:
l.backward()
for j, l in enumerate(loss):
out_grad[j].append(out_column[j][0].grad.data.clone())
# Backprop all loss through transformer.
if self.opt is not None:
out_grad = [Variable(torch.cat(og, dim=1)) for og in out_grad]
o1 = out
o2 = nn.parallel.gather(out_grad,
target_device=self.devices[0])
o1.backward(gradient=o2)
self.opt.step()
self.opt.optimizer.zero_grad()
return total * normalize
其中,有一些torch的并行化原语:
scatter - split batches onto different gpus
parallel_apply - apply module to batches on different gpus
gather - pull scattered data back onto one gpu.
nn.DataParallel - a special module wrapper that calls these all before evaluating.
官方的解释:
replicate: replicate a Module on multiple devices
scatter: distribute the input in the first-dimension
gather: gather and concatenate the input in the first-dimension
parallel_apply: apply a set of already-distributed inputs to a set of already-distributed models.
其中核心代码在__call__函数中,此处分别按数据并行和句子并行进行了分布式的计算。
先介绍下attention的数据维度,这里用的de-en翻译数据IWSLT,原语言是de,目标语言是en,de语vocab大小58947,en语vocab大小36323。
代码构造了一个dataset,其中包含几个变量:
src:输入语言的向量,shape=[batch_size, 输入句子长度],每一位是一个int数据,表示单词在词表里的idx
src_mask:输入语言的mask向量,shape=[batch_size, 输入句子长度],每一位是一个bool值,表示是否mask这个数据
trg:encode的输入向量,shape=[batch_size, 输出句子长度],每一位是一个int数据,表示单词在词表里的idx
trg_mask:输出语言的mask向量,shape=[batch_size, 输出句子长度],每一位是一个bool值,表示是否mask这个数据
trg:label,需要学习的向量(对应encode的输出向量),shape=[batch_size, 输出句子长度],每一位是一个int数据,表示单词在词表里的idx,就是trg向量中,每个单词向后挪一位(为了预测目标语言的下一个单词)
为什么会有trg和trg_y?
对于训练来说(Teaching Forcing模式),Decoder有一个输入和一个输出。比如句子” it is a good day “,输入会变成” it is a good day",而输出为"it is a good day "。对应到代码里,self.trg就是输入,而self.trg_y就是输出。接着对输入self.trg进行mask,使得Self-Attention不能访问未来的输入。这是通过make_std_mask函数实现的,这个函数会调用我们之前详细介绍过的subsequent_mask函数。最终得到的trg_mask的shape是(48/batch, 24, 24),表示24个时刻的Mask矩阵,这是一个对角线以及之下都是1的矩阵,前面已经介绍过了。
可以看真实训练的时候,src的mask所有的值都是True,但trg_mask只有当前能看到的词的mask才是True,看不到的都是False
其中,数据并行很好理解:将数据按batch拆分,分别送入不同的GPU进行推理,并计算返回值
句子并行:对句子进行拆分,将前后两段拆开,先对前面一段进行线性预测,利用得到的结果继续推理后一段。
举个例子:有两句话:
David Gallo: Das ist Bill Lange. Ich bin Dave Gallo.
Wir werden Ihnen einige Geschichten über das Meer in Videoform erzählen.
David Gallo: Das ist Bill
Wir werden Ihnen einige Geschichten
进行预测,其中第一句的前半部分在GPU0中预测,第二句的前半部分在GPU1中预测。(这里是句子并行)
注意:其实句子并行化只在decoder输出之后进行,也就是只对线性推理部分进行了并行;
encoder-decoder的并行计算其实还是数据并行的,核心来自于:
model_par = nn.DataParallel(model, device_ids=devices)
这一句话,详细关于DataParallel的介绍,可以参考https://zhuanlan.zhihu.com/p/102697821
相关资料:
https://zhuanlan.zhihu.com/p/118601295