使用pytorch从头实现多层LSTM

代码如下:

# Custom multi-layer LSTM implemented from scratch (no nn.LSTM).
class NaiveCustomLSTM(nn.Module):
    """A stacked (multi-layer) LSTM built from explicit per-gate parameters.

    Input is batch-first: ``x`` has shape ``(batch, seq_len, input_size)``.

    ``forward`` returns ``(hidden_seq, (h_t, c_t))`` where ``hidden_seq`` is
    the last layer's output sequence with shape ``(batch, seq_len,
    hidden_size)`` and ``h_t`` / ``c_t`` are the last layer's final hidden /
    cell states, each of shape ``(batch, hidden_size)``.
    """

    def __init__(self, input_size, hidden_size, num_layers=2):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Per-layer lists of parameter attribute names, in the fixed order:
        # W_i, U_i, W_f, U_f, W_c, U_c, W_o, U_o, b_i, b_f, b_c, b_o.
        self.param_names = []
        for layer in range(self.num_layers):
            # Layer 0 consumes the raw input; deeper layers consume the
            # previous layer's hidden states.  NOTE: use a local variable —
            # the original code clobbered self.input_size here.
            layer_input_size = input_size if layer == 0 else hidden_size

            # W_* multiplies the layer input, U_* multiplies the previous
            # hidden state, b_* is the gate bias.
            layer_params = (
                nn.Parameter(torch.Tensor(layer_input_size, hidden_size)),  # W_i
                nn.Parameter(torch.Tensor(hidden_size, hidden_size)),       # U_i
                nn.Parameter(torch.Tensor(layer_input_size, hidden_size)),  # W_f
                nn.Parameter(torch.Tensor(hidden_size, hidden_size)),       # U_f
                nn.Parameter(torch.Tensor(layer_input_size, hidden_size)),  # W_c
                nn.Parameter(torch.Tensor(hidden_size, hidden_size)),       # U_c
                nn.Parameter(torch.Tensor(layer_input_size, hidden_size)),  # W_o
                nn.Parameter(torch.Tensor(hidden_size, hidden_size)),       # U_o
                nn.Parameter(torch.Tensor(hidden_size)),                    # b_i
                nn.Parameter(torch.Tensor(hidden_size)),                    # b_f
                nn.Parameter(torch.Tensor(hidden_size)),                    # b_c
                nn.Parameter(torch.Tensor(hidden_size)),                    # b_o
            )
            names = ['weight_W_i{}', 'weight_U_i{}', 'weight_W_f{}', 'weight_U_f{}',
                     'weight_W_c{}', 'weight_U_c{}', 'weight_W_o{}', 'weight_U_o{}',
                     'bias_b_i{}', 'bias_b_f{}', 'bias_b_c{}', 'bias_b_o{}']
            names = [n.format(layer) for n in names]
            for name, param in zip(names, layer_params):
                # setattr on an nn.Module registers the nn.Parameter, so
                # self.parameters() and optimizers see all of them.
                setattr(self, name, param)
            self.param_names.append(names)
        self.init_weights()

    def reset_parameters(self):
        """Re-initialize every parameter uniformly in [-1/sqrt(H), 1/sqrt(H)]."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def init_weights(self):
        """Backward-compatible alias for reset_parameters()."""
        self.reset_parameters()

    def forward(self, x, init_states=None):
        """Run the stacked LSTM over a batch-first sequence.

        Args:
            x: tensor of shape ``(batch, seq_len, input_size)``.
            init_states: optional ``(h_0, c_0)`` pair, each of shape
                ``(batch, hidden_size)``, used as the initial state of every
                layer; zeros are used when ``None``.

        Returns:
            ``(hidden_seq, (h_t, c_t))`` — last layer's output sequence of
            shape ``(batch, seq_len, hidden_size)`` and its final states.
        """
        bs, seq_sz, _ = x.size()
        hidden_seq = x  # reassigned per layer below
        for layer in range(self.num_layers):
            # Each layer starts from its own initial state; the original
            # code incorrectly carried layer k-1's final state into layer k.
            if init_states is None:
                h_t = torch.zeros(bs, self.hidden_size, device=x.device)
                c_t = torch.zeros(bs, self.hidden_size, device=x.device)
            else:
                h_t, c_t = init_states
            # Fetch this layer's registered parameters in the fixed order.
            W_i, U_i, W_f, U_f, W_c, U_c, W_o, U_o, b_i, b_f, b_c, b_o = (
                getattr(self, name) for name in self.param_names[layer])
            outputs = []
            for t in range(seq_sz):
                x_t = x[:, t, :]
                # Standard LSTM cell equations.
                i_t = torch.sigmoid(x_t @ W_i + h_t @ U_i + b_i)
                f_t = torch.sigmoid(x_t @ W_f + h_t @ U_f + b_f)
                g_t = torch.tanh(x_t @ W_c + h_t @ U_c + b_c)
                o_t = torch.sigmoid(x_t @ W_o + h_t @ U_o + b_o)
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t)
                # BUGFIX: the original did `h_t = h_t[0]`, which dropped the
                # batch dimension after the first timestep.
                outputs.append(h_t.unsqueeze(1))  # (batch, 1, hidden)
            hidden_seq = torch.cat(outputs, dim=1)  # (batch, seq, hidden)
            x = hidden_seq  # feed this layer's outputs into the next layer

        return hidden_seq, (h_t, c_t)

你可能感兴趣的: pytorch, lstm, 深度学习