The OpenNMT parameter max_generator_batches

This post looks at one OpenNMT (onmt) option:

【max_generator_batches】

 

The parameter defaults to 32:

(Shown below is the default setting in opts.py, the option file of the OpenNMT source code.)

The help string roughly means:

max_generator_batches is the maximum number of words of a sequence on which the generator is run in parallel. Higher is faster, but uses more memory. Set it to 0 to disable.

I was a bit confused when I first saw this. After going back and forth through the code, my tentative understanding is: when the model produces the output sequence for an input, the generator is not applied to the whole sequence in one pass; instead the output is cut into chunks ("shards") of 32 words, and the generator/loss is run on one chunk at a time.

    group.add('--max_generator_batches', '-max_generator_batches',
              type=int, default=32,
              help="Maximum batches of words in a sequence to run "
                   "the generator on in parallel. Higher is faster, but "
                   "uses more memory. Set to 0 to disable.")

 

The code that this parameter feeds into:

In train.py:

【shard_size is assigned opt.max_generator_batches, i.e. 32】

train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt)

shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0
loss, batch_stats = self.train_loss( 
                        batch,
                        outputs,
                        attns,
                        normalization=normalization,
                        shard_size=self.shard_size,
                        trunc_start=j,
                        trunc_size=trunc_size) 
# when shard_size != 0, loss is None: the backward pass happens inside the loss call  ***!!!
# when shard_size == 0, loss has a value and the backward pass happens in the code below  ***!!!
if loss is not None:
    self.optim.backward(loss) 

In loss.py:

The following is build_loss_compute, the function that constructs the train_loss object used above:

def build_loss_compute(model, tgt_field, opt, train=True):
    # ... (part of the code omitted here)
    if opt.copy_attn: 
        # my run takes this branch (copy_attn is enabled)
        compute = onmt.modules.CopyGeneratorLossCompute(
            criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength,
            lambda_coverage=opt.lambda_coverage
        )  # this constructs CopyGeneratorLossCompute (its __init__)
    else:
        compute = NMTLossCompute(
            criterion, loss_gen, lambda_coverage=opt.lambda_coverage)
    compute.to(device)
    return compute

In copy_generator.py:

The loss is returned from the _compute_loss function in the code below.

The inheritance chain of this class is CopyGeneratorLossCompute -> NMTLossCompute -> LossComputeBase.

class CopyGeneratorLossCompute(NMTLossCompute):
    """Copy Generator Loss Computation."""
    def _compute_loss(self, batch, output, target, copy_attn, align,
                      std_attn=None, coverage_attn=None):
        """Compute the loss.
        The args must match :func:`self._make_shard_state()`.
        Args:
            batch: the current batch.
            output: the predict output from the model.
            target: the validate target to compare output with.
            copy_attn: the copy attention value.
            align: the align info.
        """
        scores = self.generator(self._bottle(output), self._bottle(copy_attn), batch.src_map)
        loss = self.criterion(scores, align, target)
        # this block does not depend on the loss value computed above
        # and is used only for stats
        scores_data = collapse_copy_scores(
            self._unbottle(scores.clone(), batch.batch_size),
            batch, self.tgt_vocab, None)
        scores_data = self._bottle(scores_data)
        # this block does not depend on the loss value computed above and is used only for stats
        # Correct target copy token instead of <unk>
        # tgt[i] = align[i] + len(tgt_vocab)
        # for i such that tgt[i] == 0 and align[i] != 0
        target_data = target.clone()
        unk = self.criterion.unk_index
        correct_mask = (target_data == unk) & (align != unk)
        offset_align = align[correct_mask] + len(self.tgt_vocab)
        target_data[correct_mask] += offset_align

        # Compute sum of perplexities for stats
        stats = self._stats(loss.sum().clone(), scores_data, target_data)
        loss = loss.sum()
        return loss, stats
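To make the <unk>-correction block above concrete, here is a toy walk-through (my own example; it assumes, as in OpenNMT vocabularies, that the unk index is 0, and uses a made-up target vocabulary size of 5). A target position that is <unk> but has a non-unk copy alignment gets remapped into the extended copy vocabulary, so the statistics count it as a correctly copied token rather than an unknown:

import torch

unk = 0               # assumption: unk index is 0
len_tgt_vocab = 5     # made-up target vocabulary size

target = torch.tensor([3, 0, 0, 2])   # positions 1 and 2 are <unk>
align  = torch.tensor([0, 4, 0, 0])   # position 1 is aligned to source word 4

target_data = target.clone()
correct_mask = (target_data == unk) & (align != unk)   # only position 1 qualifies
offset_align = align[correct_mask] + len_tgt_vocab     # 4 + 5 = 9
target_data[correct_mask] += offset_align              # the <unk> (0) becomes 9

print(target_data)   # tensor([3, 9, 0, 2]): index 9 points into the extended (copy) vocabulary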

But what calls this function?

Next, look at the base class LossComputeBase in loss.py.

It is easy to see that the base class's __call__ method is what calls the _compute_loss function from the previous snippet.

That is, the main shard-related code is the following:

class LossComputeBase(nn.Module):
    """
    Handles sharding next step predictions and accumulating multiple loss computations
    Users can implement their own loss computation strategy by making subclass of this one.  
    Users need to implement the _compute_loss()  and make_shard_state() methods.
    """
    def __call__(self, batch, output, attns, normalization=1.0,
                 shard_size=0, trunc_start=0, trunc_size=None):  # shard_size defaults to 0
        """Compute the forward loss, possibly in shards in which case this
        method also runs the backward pass and returns ``None`` as the loss value.
 
        Note sharding is an exact efficiency trick to relieve memory required for the generation buffers. 
        Truncation is an approximate efficiency trick to relieve the memory required in the RNN buffers.
        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :output of decoder `[tgt_len x batch x hidden]`
          attns (dict) : `[tgt_len x batch x src_len]`
          shard_size (int) : maximum number of examples in a shard
        Returns:
            A tuple with the loss and a :obj:`onmt.utils.Statistics` instance.
        """
        if trunc_size is None:
            trunc_size = batch.tgt.size(0) - trunc_start
        trunc_range = (trunc_start, trunc_start + trunc_size) 
        shard_state = self._make_shard_state(batch, output, trunc_range, attns)

        if shard_size == 0:  # when shard_size is 0, compute and return the loss value
            loss, stats = self._compute_loss(batch, **shard_state)
            return loss / float(normalization), stats

        # shard_size != 0: each shard's loss is backprop'd right here; no loss value is returned
        batch_stats = onmt.utils.Statistics() 
        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)
            loss.div(float(normalization)).backward() # backward 2/2 -------------- loss 2/2 --------------
            batch_stats.update(stats)  
        return None, batch_stats

From the code above we can see:

 

When shard_size == 0, there is only a single loss.backward().

In the first train.py snippet earlier in this post, we saw:

if loss is not None:
    self.optim.backward(loss) 

So when shard_size == 0, the only loss.backward() lives in train.py. This is the familiar pattern: first compute the loss value with the loss function, then call loss.backward().
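As a reminder, this is that familiar single-backward pattern in its simplest form (a toy sketch, unrelated to the OpenNMT classes):

import torch

x = torch.randn(8, 4)
w = torch.randn(4, 1, requires_grad=True)
target = torch.randn(8, 1)

loss = ((x @ w - target) ** 2).mean()   # 1) compute the full loss value first
loss.backward()                         # 2) a single backward through the whole graph
print(w.grad.shape)                     # torch.Size([4, 1])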

 

When shard_size != 0, the code goes through the shards function, and backward ends up being called twice.

def shards(state, shard_size, eval_only=False):
    """
    Args:
        shard_size: The maximum size of the shards yielded by the model.
        eval_only: If True, only yield the state, nothing else.
              Otherwise, yield shards.
    Yields:
        Each yielded shard is a dict.
    Side effect:
        After the last shard, this function does back-propagation.
    """
    if eval_only:
        yield filter_shard_state(state)
    else:
        non_none = dict(filter_shard_state(state, shard_size))  # non_none: the sub-dict of state whose values are not None
        # non_none is a dict of sequences of tensor-likes, but we need a sequence of dicts of tensors.
        # First, unzip the dict into a sequence of keys and a sequence of tensor-like sequences.
        keys, values = zip(*((k, [v_chunk for v_chunk in v_split])
                             for k, (_, v_split) in non_none.items()))
        # Now, yield a dict for each shard. The keys are always the same.
        # values is a sequence of length #keys 
        # where each element is a sequence of length #shards. 
        # We want to iterate over the shards, not the keys, so the values need to be
        # re-zipped by shard so that each shard can be paired with the keys.
        for shard_tensors in zip(*values):
            yield dict(zip(keys, shard_tensors))
        # Assumed backprop'd
        variables = []
        for k, (v, v_split) in non_none.items():
            if isinstance(v, torch.Tensor) and state[k].requires_grad: 
                variables.extend(zip(torch.split(state[k], shard_size), [v_chunk.grad for v_chunk in v_split]))
        inputs, grads = zip(*variables) # inputs : tuple 
        torch.autograd.backward(inputs, grads) # backward 1/2 -------------- loss 1/2 -------------- 

Inside the shards function, the sharded data is produced with yield and handed back to __call__; then, after the last shard, still inside shards, torch.autograd.backward(inputs, grads) is called for every tensor that requires grad.

Meanwhile, in __call__, the loss computed on each shard gets its own loss.div(float(normalization)).backward();

which is how, when shard_size != 0, the code ends up performing backward twice.

Specifically, the shards function relies on the helper below, whose job is to iterate over the (k, v) pairs of the input state and slice each tensor value v into chunks:

def filter_shard_state(state, shard_size=None):
    for k, v in state.items():
        if shard_size is None:
            yield k, v

        if v is not None:
            v_split = []
            if isinstance(v, torch.Tensor):
                for v_chunk in torch.split(v, shard_size):
                    v_chunk = v_chunk.data.clone()
                    v_chunk.requires_grad = v.requires_grad
                    v_split.append(v_chunk)
            yield k, (v, v_split)
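A quick toy usage of filter_shard_state (using the definition above; the state keys here are made up) shows what it yields when shard_size is set:

import torch

state = {
    "output": torch.randn(5, 3, requires_grad=True),   # split into detached chunks of 2 rows
    "target": torch.arange(5),                          # also split; requires_grad stays False
}

for k, (v, v_split) in filter_shard_state(state, shard_size=2):
    print(k, [tuple(chunk.shape) for chunk in v_split])
# output [(2, 3), (2, 3), (1, 3)]
# target [(2,), (2,), (1,)]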

Why are two backward passes needed?

Because of this line: v_chunk = v_chunk.data.clone()

.data: gives access to the underlying Tensor, detached from the computation graph. Differentiating through it silently produces a wrong result (0) instead of raising an error.

.clone(): makes an exact copy of the Tensor that stays attached to the computation graph; it does not detach.
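A tiny self-contained illustration of the difference (my own example, not OpenNMT code):

import torch

x = torch.randn(3, requires_grad=True)

a = x.clone()            # stays in the graph: gradients flow back to x
a.sum().backward()
print(x.grad)            # tensor([1., 1., 1.])

x.grad = None
b = x.data.clone()       # detached copy: no connection to x in the graph
b.requires_grad = True   # make b a leaf so backward can populate b.grad
b.sum().backward()
print(x.grad)            # None: nothing flows back to x
print(b.grad)            # tensor([1., 1., 1.]): the gradient stops at b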

filter_shard_state clones the .data of the inputs to the loss computation. As a result, loss.backward() can only propagate gradients back as far as those (detached) inputs of the loss function; it stops there instead of reaching the whole computation graph.

torch.autograd.backward(inputs, grads) then takes the gradients computed by those earlier backward calls, together with the corresponding inputs that are still attached to the full computation graph, and backpropagates them through the remaining variables.
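Putting the whole trick together, here is a minimal self-contained sketch of the two-stage backward that shards implements (toy model and made-up names, not the OpenNMT code): each per-shard loss.backward() only reaches its detached chunk, and torch.autograd.backward then pushes the collected chunk gradients through the rest of the graph.

import torch

torch.manual_seed(0)

encoder = torch.nn.Linear(4, 4)     # stands in for everything below the shard boundary
generator = torch.nn.Linear(4, 1)   # stands in for the memory-hungry generator + loss
inputs = torch.randn(6, 4)

hidden = encoder(inputs)            # full forward pass up to the shard boundary

shard_size = 2
attached_chunks, chunk_grads = [], []
for h_chunk in torch.split(hidden, shard_size):
    # detach the chunk, like v_chunk = v_chunk.data.clone() in filter_shard_state
    h_detached = h_chunk.data.clone()
    h_detached.requires_grad = True
    shard_loss = generator(h_detached).sum()
    shard_loss.backward()           # "backward 2/2": stops at the detached chunk
    attached_chunks.append(h_chunk) # this chunk is still connected to the encoder
    chunk_grads.append(h_detached.grad)

# "backward 1/2": push the collected chunk gradients through the rest of the graph
torch.autograd.backward(attached_chunks, chunk_grads)
print(encoder.weight.grad is not None)   # True: the encoder received gradients

The benefit is that the large generator activations only ever exist for one shard at a time, while the part of the graph below the shard boundary is backpropagated exactly once at the end.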

 

Why, after v_chunk = v_chunk.data.clone(), do we still need v_chunk.requires_grad = v.requires_grad? Because .data (and a clone of it) is detached and has requires_grad=False by default; restoring requires_grad turns each chunk into a leaf tensor, so the per-shard loss.backward() actually computes and stores a gradient in v_chunk.grad, which is exactly what shards later passes to torch.autograd.backward.

 
