This article walks through the main code fragments of Keras' LSTM implementation, together with the points I consider important. The overall structure is roughly as follows, with a few if branches selecting between the different LSTM implementation functions:
```python
# The input should be dense, padded with zeros. If a ragged input is fed
# into the layer, it is padded and the row lengths are used for masking.
inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
is_ragged_input = (row_lengths is not None)
self._validate_args_if_ragged(is_ragged_input, mask)

# LSTM does not support constants. Ignore it during process.
inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)

input_shape = K.int_shape(inputs)
timesteps = input_shape[0] if self.time_major else input_shape[1]

last_output, outputs, states = K.rnn(
    step,
    inputs,
    initial_state,
    constants=None,
    go_backwards=self.go_backwards,
    mask=mask,
    unroll=self.unroll,
    input_length=row_lengths if row_lengths is not None else timesteps,
    time_major=self.time_major,
    zero_output_for_mask=self.zero_output_for_mask)

# return_sequences: return the full output sequence, or only the last output
if self.return_sequences:
    output = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths)
else:
    output = last_output
```
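To see the ragged-input path in action, here is a minimal usage sketch (my own illustration, not part of the Keras source): a tf.RaggedTensor of variable-length sequences is fed to the layer, which pads it internally and derives the mask from the row lengths.

```python
import tensorflow as tf

# Illustrative sketch: two sequences of lengths 3 and 1, each timestep a
# 2-dim feature vector, packed as a RaggedTensor.
ragged = tf.ragged.constant(
    [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
     [[7.0, 8.0]]],
    ragged_rank=1)

lstm = tf.keras.layers.LSTM(4, return_sequences=True)

# The layer pads the input and uses the row lengths (3 and 1) for masking;
# with return_sequences=True, maybe_convert_to_ragged converts the padded
# output back to a RaggedTensor.
out = lstm(ragged)
print(out.shape)  # expected: (2, None, 4), with a ragged time dimension
```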
On the CPU path this ends up in standard_lstm, whose inner step() function computes a single LSTM timestep:

```python
def step(cell_inputs, cell_states):
    """Step function that will be used by Keras RNN backend."""
    h_tm1 = cell_states[0]  # previous memory state
    c_tm1 = cell_states[1]  # previous carry state

    # One fused matmul for all four gates, then split into quarters.
    z = K.dot(cell_inputs, kernel)       # kernel: weights for cell kernel
    z += K.dot(h_tm1, recurrent_kernel)  # recurrent_kernel: recurrent weights
    z = K.bias_add(z, bias)
    z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)
    print(z0, z1, z2, z3)  # debug print for inspecting the gate pre-activations

    i = recurrent_activation(z0)        # input gate
    f = recurrent_activation(z1)        # forget gate
    c = f * c_tm1 + i * activation(z2)  # next carry state
    o = recurrent_activation(z3)        # output gate
    h = o * activation(c)               # next memory state
    return h, [h, c]

last_output, outputs, new_states = K.rnn(
    step,
    inputs, [init_h, init_c],
    constants=None,
    unroll=False,
    time_major=time_major,
    mask=mask,
    go_backwards=go_backwards,
    input_length=(sequence_lengths
                  if sequence_lengths is not None else timesteps),
    zero_output_for_mask=zero_output_for_mask)
return (last_output, outputs, new_states[0], new_states[1],
        _runtime(_RUNTIME_CPU))
```
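As a sanity check on the gate ordering (input, forget, cell candidate, output) and the fused kernel split above, here is a small NumPy re-implementation of the same step. It is only a sketch for illustration; the names and shapes mirror the Keras variables but are my own assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_tm1, c_tm1, kernel, recurrent_kernel, bias):
    """One LSTM step in NumPy, mirroring the Keras step() above.

    x_t:    (batch, input_dim)
    h_tm1:  (batch, units)  previous memory state
    c_tm1:  (batch, units)  previous carry state
    kernel: (input_dim, 4 * units), recurrent_kernel: (units, 4 * units)
    """
    z = x_t @ kernel + h_tm1 @ recurrent_kernel + bias
    z0, z1, z2, z3 = np.split(z, 4, axis=1)  # gate order: i, f, c-candidate, o
    i = sigmoid(z0)                      # input gate
    f = sigmoid(z1)                      # forget gate
    c = f * c_tm1 + i * np.tanh(z2)      # next carry state
    o = sigmoid(z3)                      # output gate
    h = o * np.tanh(c)                   # next memory state
    return h, c

# Tiny usage example with random weights (hypothetical shapes).
rng = np.random.default_rng(0)
batch, input_dim, units = 2, 3, 4
h, c = lstm_step(
    rng.standard_normal((batch, input_dim)),
    np.zeros((batch, units)), np.zeros((batch, units)),
    rng.standard_normal((input_dim, 4 * units)),
    rng.standard_normal((units, 4 * units)),
    np.zeros(4 * units))
print(h.shape, c.shape)  # (2, 4) (2, 4)
```

The fused 4 * units layout is why a single matmul covers all four gates at once before the split.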
Both calls above land in the backend function K.rnn:

```python
def rnn(step_function,
        inputs,
        initial_states,
        go_backwards=False,
        mask=None,
        constants=None,
        unroll=False,
        input_length=None,
        time_major=False,
        zero_output_for_mask=False):
    """Iterates over the time dimension of a tensor.

    input_length: An integer or a 1-D Tensor, depending on whether
        the time dimension is fixed-length or not. In case of variable length
        input, it is used for masking in case there's no mask specified.
    """
```

In the calls shown above, input_length is therefore the per-sample lengths (row_lengths / sequence_lengths) when the input has variable length, and simply timesteps otherwise.
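For the variable-length case, a 1-D input_length implies a boolean time mask. A quick illustration of that correspondence using tf.sequence_mask (my own example, not from the Keras source):

```python
import tensorflow as tf

row_lengths = tf.constant([3, 1])  # true lengths of two padded sequences
print(tf.sequence_mask(row_lengths, maxlen=4).numpy())
# [[ True  True  True False]
#  [ True False False False]]
```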
The step_function argument here is the per-timestep RNN step function, e.g. the step() inside standard_lstm shown above. The unrolled branch then simply loops over the timesteps:
```python
# Unrolled branch: compute each timestep in a Python loop.
for i in range(time_steps):
    inp = _get_input_tensor(i)
    mask_t = mask_list[i]
    output, new_states = step_function(inp,
                                       tuple(states) + tuple(constants))
```
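To make the calling convention concrete, here is a hypothetical pure-Python mock of this loop with a trivial step function (a running sum standing in for the LSTM cell); it illustrates only the contract output, new_states = step_function(inp, states + constants), not the real backend.

```python
# Hypothetical mock of the unrolled loop with a trivial "cell":
# the state is a running sum of the inputs.
def step_function(inp, states):
    (acc,) = states
    new_acc = acc + inp        # state update
    return new_acc, [new_acc]  # (output, new_states), matching K.rnn's contract

inputs = [1.0, 2.0, 3.0]       # one scalar "timestep" at a time
states, constants = [0.0], ()  # initial state; the LSTM path passes constants=None

outputs = []
for i in range(len(inputs)):   # the unrolled time loop
    inp = inputs[i]
    output, states = step_function(inp, tuple(states) + tuple(constants))
    outputs.append(output)

print(outputs)  # [1.0, 3.0, 6.0]
```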