Pytorch LSTMCell踩坑

Pytorch LSTMCell踩坑

背景: 使用torch.nn.LSTMCell编写多层lstm网络。

self.lstm_cells = [nn.LSTMCell(self.embed_dim, self.hidden_dim).cuda()]
for i in range(num_layers - 1):
      self.lstm_cells.append(nn.LSTMCell(self.hidden_dim, self.hidden_dim).cuda())

问题描述: 列表lstm_cells为class Model的一个模块,参数含义顾名思义。但是model.state_dict()和model.Parameters()中并不包含lstm_cell中的参数(model为Model的实例化),这样在训练时无法更新权重参数。且在模型持久化和加载的时候无法对该模块进行操作。
解决办法: 使用nn.ModuleList()代替list进行操作。nn.ModuleList自动会将参数注册为Parameter,其他操作与list基本相同。

self.lstm_cells = nn.ModuleList([nn.LSTMCell(self.embed_dim, self.hidden_dim)])
    for i in range(num_layers - 1):
      self.lstm_cells.append(nn.LSTMCell(self.hidden_dim, self.hidden_dim))

你可能感兴趣的:(笔记,pytorch)