Convolutional LSTM(convLSTM)由《Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting》一文提出,用于降水临近预报。该网络结构既考虑了输入之间的空间关联,又考虑了时序信息,因此也被用于视频分析。
GitHub 上已有多个 convLSTM 的 PyTorch 实现,这里选用 Convolution_LSTM_pytorch 进行调试运行。
文件中定义了 `ConvLSTM` 和 `ConvLSTMCell` 两个类,并给出了一段调用代码。
`ConvLSTM` 类包含 `__init__` 和 `forward` 两个方法。
`__init__`:根据输入参数定义一个多层的 convLSTM。
def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=None):
    """Build a multi-layer convLSTM out of stacked ConvLSTMCell layers.

    Args:
        input_channels: channel count of the tensor passed to ``forward``.
        hidden_channels: sequence of hidden channel counts, one per layer;
            its length sets the number of layers.
        kernel_size: convolution kernel size shared by every layer.
        step: number of time steps ``forward`` runs.
        effective_step: step indices (0-based) at which ``forward`` records
            the last layer's hidden state. ``None`` means ``[1]``.

    NOTE(review): with the defaults (step=1, effective_step=[1]) ``forward``
    only runs step 0, so no step is ever recorded — confirm this is intended.
    """
    super(ConvLSTM, self).__init__()
    # Fresh list per instance; a ``[1]`` literal default would be shared by
    # every instance created without this argument.
    if effective_step is None:
        effective_step = [1]
    # Accept any sequence (e.g. a tuple), not only a list.
    hidden_channels = list(hidden_channels)
    # Layer 0 consumes the network input; layer i > 0 consumes layer i-1's hidden state.
    self.input_channels = [input_channels] + hidden_channels
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size
    self.num_layers = len(hidden_channels)
    self.step = step
    self.effective_step = effective_step
    self._all_layers = []
    for i in range(self.num_layers):
        name = 'cell{}'.format(i)
        cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size)
        # Stored as attributes cell0, cell1, ... (forward looks them up by this name);
        # assuming ConvLSTM is an nn.Module, setattr also registers them as submodules.
        setattr(self, name, cell)
        self._all_layers.append(cell)
`forward`:多层 convLSTM 的多时间步前向传播。
def forward(self, input):
    """Run the stacked convLSTM for ``self.step`` time steps.

    The same ``input`` tensor is fed to the first layer at every step, and
    each layer's new hidden state becomes the next layer's input. Hidden and
    cell states are re-created through each cell's ``init_hidden`` on step 0
    of every call, so no state carries over between calls.

    Args:
        input: 4-D tensor, unpacked as (batch, channels, height, width).

    Returns:
        ``(outputs, (h, c))``. ``outputs`` is a list holding the last layer's
        hidden state for every step index found in ``self.effective_step``.
        ``(h, c)`` is the last layer's hidden and cell state after the final step.

    Raises:
        ValueError: if ``self.step < 1`` or there are no layers, because no
            state would ever be produced to return.
    """
    if self.step < 1 or self.num_layers < 1:
        raise ValueError(
            'ConvLSTM.forward needs step >= 1 and at least one layer '
            '(got step={}, num_layers={})'.format(self.step, self.num_layers))
    internal_state = []
    outputs = []
    for step in range(self.step):  # one pass through all layers per time step
        x = input
        for i in range(self.num_layers):  # run each ConvLSTMCell in order
            cell = getattr(self, 'cell{}'.format(i))
            if step == 0:
                # All cells are initialised on the first step, sized from the
                # tensor this layer actually receives.
                bsize, _, height, width = x.size()
                h, c = cell.init_hidden(batch_size=bsize, hidden=self.hidden_channels[i],
                                        shape=(height, width))
                internal_state.append((h, c))
            h, c = internal_state[i]
            x, new_c = cell(x, h, c)  # ConvLSTMCell.forward
            internal_state[i] = (x, new_c)
        # Only record the requested ("effective") steps.
        if step in self.effective_step:
            outputs.append(x)
    # The last layer's state after the final step, i.e. the original (x, new_c).
    return outputs, internal_state[-1]
`ConvLSTMCell` 类包含 `__init__`、`forward` 和 `init_hidden` 三个方法。
`__init__`:初始化一个 convLSTM 单元。
def __init__(self, input_channels, hidden_channels, kernel_size):
    """Create the convolutions of a single convLSTM cell.

    Args:
        input_channels: channel count of the input ``x``.
        hidden_channels: channel count of the hidden and cell states; this
            implementation requires it to be even.
        kernel_size: int side length of the square kernels. Must be odd:
            padding is ``(kernel_size - 1) // 2`` with stride 1, which keeps
            height/width unchanged only for odd kernels. ``forward`` combines
            the convolution outputs element-wise with the cell state ``c``,
            so the spatial size must be preserved.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """
    super(ConvLSTMCell, self).__init__()
    # NOTE(review): nothing in this cell visibly depends on an even channel
    # count; the check is kept for compatibility.
    assert hidden_channels % 2 == 0
    if kernel_size % 2 == 0:
        raise ValueError(
            'kernel_size must be odd to preserve the spatial size, got {}'.format(kernel_size))
    self.input_channels = input_channels
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size
    self.num_features = 4  # presumably the four gate paths (i, f, c, o); not read in this class
    # Integer "same" padding for a stride-1 convolution with an odd kernel.
    self.padding = (kernel_size - 1) // 2
    # Input-to-state convolutions carry the bias; state-to-state ones do not.
    # Creation order is kept as-is (it fixes parameter order and RNG use at init).
    self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
    self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
    self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
    self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
    self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
    self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
    self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
    self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
    # Peephole weights; created lazily by init_hidden once the spatial size is known.
    self.Wci = None
    self.Wcf = None
    self.Wco = None
`forward`:单个 convLSTM 单元的前向传播,即 convLSTM 中最核心的 5 个公式;输出的 `ch` 和 `cc` 分别代表当前的 hidden state 和 cell state。
def forward(self, x, h, c):
    """Advance the cell by one time step.

    Args:
        x: input at the current step.
        h: previous hidden state.
        c: previous cell state.

    Returns:
        ``(ch, cc)``: the new hidden state and the new cell state.
    """
    # Input and forget gates: x/h convolutions plus peepholes on the old cell state.
    input_gate = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci)
    forget_gate = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf)

    # Candidate content, blended with the retained part of the old cell state.
    candidate = torch.tanh(self.Wxc(x) + self.Whc(h))
    cell_next = forget_gate * c + input_gate * candidate

    # The output gate's peephole looks at the *new* cell state.
    output_gate = torch.sigmoid(self.Wxo(x) + self.Who(h) + cell_next * self.Wco)
    hidden_next = output_gate * torch.tanh(cell_next)
    return hidden_next, cell_next
`init_hidden`:初始化 convLSTMCell 的状态,返回全零的初始 hidden state 和 cell state;首次调用时还会创建窥孔(peephole)权重 `Wci`、`Wcf`、`Wco`。
def init_hidden(self, batch_size, hidden, shape, device=None, dtype=None):
    """Return zero-initialised ``(hidden, cell)`` states for this cell.

    On the first call this also creates the peephole weights ``Wci``,
    ``Wcf`` and ``Wco`` as zero tensors of shape (1, hidden, height, width).
    Later calls only check that the spatial size still matches them.

    Args:
        batch_size: batch dimension of the returned states.
        hidden: channel count of the states (and of new peephole weights).
        shape: ``(height, width)`` of the states.
        device: optional device for newly created tensors; ``None`` keeps
            torch's default (the original behaviour).
        dtype: optional dtype for newly created tensors; ``None`` keeps
            torch's default (the original behaviour).

    Returns:
        ``(h, c)``: two separate zero tensors of shape
        (batch_size, hidden, height, width).

    Raises:
        AssertionError: if ``shape`` differs from the size of the
            already-created peephole weights.
    """
    height, width = shape[0], shape[1]
    if self.Wci is None:
        # NOTE(review): plain tensors (requires_grad=False), not nn.Parameter,
        # so they receive no gradient and stay zero unless set elsewhere.
        self.Wci = torch.zeros(1, hidden, height, width, device=device, dtype=dtype)
        self.Wcf = torch.zeros(1, hidden, height, width, device=device, dtype=dtype)
        self.Wco = torch.zeros(1, hidden, height, width, device=device, dtype=dtype)
    else:
        assert height == self.Wci.size()[2], 'Input Height Mismatched!'
        assert width == self.Wci.size()[3], 'Input Width Mismatched!'
    # Variable(t) is an identity wrapper since PyTorch 0.4, so plain tensors are returned.
    return (torch.zeros(batch_size, hidden, height, width, device=device, dtype=dtype),
            torch.zeros(batch_size, hidden, height, width, device=device, dtype=dtype))
if __name__ == '__main__':
    # Demo: a 5-layer convLSTM run for 5 steps, recording only step index 4.
    # NOTE(review): gradcheck below differentiates loss_fn with respect to
    # (output, target) only, i.e. it checks MSELoss gradients, not the
    # ConvLSTM's own parameter gradients.
    convlstm = ConvLSTM(input_channels=512,
                        hidden_channels=[128, 64, 64, 32, 32],
                        kernel_size=3,
                        step=5,
                        effective_step=[4])
    loss_fn = torch.nn.MSELoss()

    frames = Variable(torch.randn(1, 512, 64, 32))
    target = Variable(torch.randn(1, 32, 64, 32)).double()

    recorded, _ = convlstm(frames)
    # First (and only) recorded hidden state, in double precision for gradcheck.
    output = recorded[0].double()

    res = torch.autograd.gradcheck(loss_fn, (output, target), eps=1e-6, raise_exception=True)
    print(res)
如果需要在其他 .py 文件中使用此模块,直接导入即可。