GRU的源码笔记

输出和状态是一样的,前一个状态为state,前一个输出也是state,其宽度都是num_units参数
重置门和更新门分别是r和u
首先输入和前一个输出拼接在一起,然后加权(_gate_kernel)再按列平分(因为r,u都是对状态的加权,所以宽度和状态的宽度一样,都是num_units参数),得到重置门r和更新门u

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._gate_kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

    value = math_ops.sigmoid(gate_inputs)
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

然后对state进行重置(遗忘)

r_state = r * state

遗忘之后的状态和输入拼接在一起,加权(_candidate_kernel)得到候选状态candidate,接着激活

    candidate = math_ops.matmul(
        array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
    candidate = nn_ops.bias_add(candidate, self._candidate_bias)

    c = self._activation(candidate)

最后对原状态和候选状态各取一定比例叠加在一起,得到新状态,和新输出

    new_h = u * state + (1 - u) * c
    return new_h, new_h

你可能感兴趣的:(LSTM)