The code is as follows:
import math

import torch
import torch.nn as nn


# Custom LSTM implementation
class NaiveCustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=2):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.param_names = []
        for layer in range(self.num_layers):
            # Layers after the first consume the previous layer's hidden state,
            # so their input width is hidden_size rather than input_size.
            layer_input_size = input_size if layer == 0 else hidden_size
            # Input gate i_t
            W_i = nn.Parameter(torch.Tensor(layer_input_size, hidden_size))
            U_i = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
            b_i = nn.Parameter(torch.Tensor(hidden_size))
            # Forget gate f_t
            W_f = nn.Parameter(torch.Tensor(layer_input_size, hidden_size))
            U_f = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
            b_f = nn.Parameter(torch.Tensor(hidden_size))
            # Cell candidate g_t
            W_c = nn.Parameter(torch.Tensor(layer_input_size, hidden_size))
            U_c = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
            b_c = nn.Parameter(torch.Tensor(hidden_size))
            # Output gate o_t
            W_o = nn.Parameter(torch.Tensor(layer_input_size, hidden_size))
            U_o = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
            b_o = nn.Parameter(torch.Tensor(hidden_size))

            layer_params = (W_i, U_i, W_f, U_f, W_c, U_c, W_o, U_o,
                            b_i, b_f, b_c, b_o)
            param_name = ['weight_W_i{}', 'weight_U_i{}', 'weight_W_f{}', 'weight_U_f{}',
                          'weight_W_c{}', 'weight_U_c{}', 'weight_W_o{}', 'weight_U_o{}',
                          'bias_b_i{}', 'bias_b_f{}', 'bias_b_c{}', 'bias_b_o{}']
            param_name = [x.format(layer) for x in param_name]
            # Register each parameter under its per-layer name so that
            # nn.Module tracks it and self.parameters() can find it.
            for name, param in zip(param_name, layer_params):
                setattr(self, name, param)
            self.param_names.append(param_name)
        self.init_weights()
    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def init_weights(self):
        # Same uniform initialization as reset_parameters.
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
    def forward(self, x, init_states=None):
        """
        Assumes x.shape represents (batch_size, sequence_size, input_size).
        """
        bs, seq_sz, _ = x.size()
        hidden_seqs = None
        for layer in range(self.num_layers):
            # Fetch this layer's parameters by their registered names.
            param_name = self.param_names[layer]
            W_i, U_i = getattr(self, param_name[0]), getattr(self, param_name[1])
            W_f, U_f = getattr(self, param_name[2]), getattr(self, param_name[3])
            W_c, U_c = getattr(self, param_name[4]), getattr(self, param_name[5])
            W_o, U_o = getattr(self, param_name[6]), getattr(self, param_name[7])
            b_i, b_f, b_c, b_o = [getattr(self, n) for n in param_name[8:12]]

            # Each layer starts from its own (zero) state; if init_states is
            # given, it is shared by every layer.
            if init_states is None:
                h_t = torch.zeros(bs, self.hidden_size, device=x.device)
                c_t = torch.zeros(bs, self.hidden_size, device=x.device)
            else:
                h_t, c_t = init_states

            hidden_seq = []
            for t in range(seq_sz):
                x_t = x[:, t, :]
                i_t = torch.sigmoid(x_t @ W_i + h_t @ U_i + b_i)  # input gate
                f_t = torch.sigmoid(x_t @ W_f + h_t @ U_f + b_f)  # forget gate
                g_t = torch.tanh(x_t @ W_c + h_t @ U_c + b_c)     # cell candidate
                o_t = torch.sigmoid(x_t @ W_o + h_t @ U_o + b_o)  # output gate
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t)
                hidden_seq.append(h_t.unsqueeze(1))
            # Reshape hidden_seq to (batch_size, sequence_size, hidden_size).
            hidden_seqs = torch.cat(hidden_seq, dim=1)
            # The next layer consumes this layer's hidden sequence.
            x = hidden_seqs
        return hidden_seqs, (h_t, c_t)
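
A quick way to sanity-check the module is to push a random batch through it and compare the output shape against torch.nn.LSTM with batch_first=True. This is a minimal sketch; the sizes below are arbitrary illustration values, not from the code above.

lstm = NaiveCustomLSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(4, 5, 10)       # (batch_size, sequence_size, input_size), arbitrary sizes
out, (h_t, c_t) = lstm(x)
print(out.shape)                # torch.Size([4, 5, 20])
print(h_t.shape, c_t.shape)     # torch.Size([4, 20]) torch.Size([4, 20])

ref = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
ref_out, _ = ref(x)
print(ref_out.shape)            # torch.Size([4, 5, 20]), same as out.shape

Note that nn.LSTM returns final states stacked over all layers, shaped (num_layers, batch_size, hidden_size), while this class returns only the last layer's (h_t, c_t), so only the output sequence shapes are directly comparable.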