About ConvLSTM

A good explanation of the LSTM itself: https://zhuanlan.zhihu.com/p/32085405

Another reference: https://blog.csdn.net/weixin_42769131/article/details/104728842
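For reference, the update implemented by the cell below can be written as follows (this is my own notation for what the code does; it is a ConvLSTM without the peephole connections used in the original ConvLSTM paper):

$$
\begin{aligned}
i_t &= \sigma\bigl(W_i * [x_t, h_{t-1}] + b_i\bigr) \\
f_t &= \sigma\bigl(W_f * [x_t, h_{t-1}] + b_f\bigr) \\
o_t &= \sigma\bigl(W_o * [x_t, h_{t-1}] + b_o\bigr) \\
g_t &= \tanh\bigl(W_g * [x_t, h_{t-1}] + b_g\bigr) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

where $*$ is a 2-D convolution (3x3 here), $[\cdot,\cdot]$ is channel-wise concatenation, and $\odot$ is element-wise multiplication. In the code, the four convolutions are fused into the single `self.Gates` layer, and its output is split into the four gates with `chunk(4, 1)`.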

import torch
import torch.nn as nn


class ConvLSTMCell(nn.Module):
    """
    Generate a convolutional LSTM cell
    """

    def __init__(self, input_size, hidden_size):
        super(ConvLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size=3, stride=1, padding=1)

    def forward(self, input_, prev_state):

        # get batch and spatial sizes
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # generate empty prev_state, if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            # allocate zero states on the same device / dtype as the input
            prev_state = (
                torch.zeros(state_size, device=input_.device, dtype=input_.dtype),
                torch.zeros(state_size, device=input_.device, dtype=input_.dtype)
            )

        prev_hidden, prev_cell = prev_state

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)

        # chunk across channel dimension
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
        # cell_gate: candidate cell content computed from the current input x_t and the
        #   previous hidden state; tanh squashes it to (-1, 1) -- it is data, not a gate
        # remember_gate (forget gate): controls how much of the previous cell state is kept
        # in_gate (input gate): selects how much of the candidate content to write into the cell
        # out_gate (output gate): decides which parts of the cell state become the output

        # apply sigmoid non linearity
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)

        # apply tanh non linearity
        cell_gate = torch.tanh(cell_gate)  # values in (-1, 1): the candidate input, not a gating signal

        # compute current cell and hidden state
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)

        return hidden, cell
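A minimal usage sketch (the tensor shapes, the sequence loop, and all variable names below are my own illustration, not part of the original post): the cell is applied frame by frame, carrying (hidden, cell) forward as prev_state.

# hypothetical example: a 10-step sequence of 8-channel 64x64 feature maps, batch size 4
seq_len, batch, channels, height, width = 10, 4, 8, 64, 64
frames = torch.randn(seq_len, batch, channels, height, width)

cell = ConvLSTMCell(input_size=channels, hidden_size=16)

state = None  # first call builds zero-initialized hidden/cell states
outputs = []
for t in range(seq_len):
    hidden, cell_state = cell(frames[t], state)
    state = (hidden, cell_state)
    outputs.append(hidden)

outputs = torch.stack(outputs)  # [seq_len, batch, hidden_size, height, width]
print(outputs.shape)            # torch.Size([10, 4, 16, 64, 64])

Note that the hidden/cell states keep the spatial size of the input (the 3x3 convolution uses padding=1), so the cell can be stacked or run over arbitrarily long sequences without changing shapes.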
