'''
# nn.RNN and nn.LSTM are constructed in almost the same way.
# For nn.LSTM you can choose whether to pass the initial states (h_0, c_0);
# if they are omitted, they default to zeros.
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
'''
import torch
from torch import nn
from torch.utils.data import DataLoader
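# A small shape check for the configuration used by Classfication_Model below
# (4 layers, bidirectional, hidden_size=128, embedding_dim=200). This is only an
# illustrative sketch with made-up sequence/batch sizes; it shows why forward()
# concatenates h_n[-2] and h_n[-1]: those are the last layer's forward and
# backward final hidden states.
_lstm = nn.LSTM(input_size=200, hidden_size=128, num_layers=4, bidirectional=True)
_x = torch.randn(50, 3, 200)        # (seq_len, batch, input_size)
_out, (_h_n, _c_n) = _lstm(_x)      # default (h_0, c_0) are zeros
print(_out.shape)                   # torch.Size([50, 3, 256]) -> num_directions * hidden_size
print(_h_n.shape)                   # torch.Size([8, 3, 128])  -> num_layers * num_directions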
class Classfication_Model(nn.Module):
    def __init__(self):
        super(Classfication_Model, self).__init__()
        self.hidden_size = 128
        self.embedding_dim = 200
        self.number_layer = 4
        self.bidirectional = True
        self.bi_number = 2 if self.bidirectional else 1
        self.dropout = 0.5
        # Vocabulary size is taken from the pretrained word-vector model, with 200
        # extra slots (presumably reserved for padding/unknown tokens).
        self.embedding = nn.Embedding(num_embeddings=len(model.index_to_key) + 200,
                                      embedding_dim=self.embedding_dim)
        self.lstm = nn.LSTM(input_size=self.embedding_dim,
                            hidden_size=self.hidden_size,
                            num_layers=self.number_layer,
                            dropout=self.dropout,
                            bidirectional=self.bidirectional)
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_size * self.bi_number, 20),
            nn.ReLU(),
            nn.Linear(20, 2)
        )

    def init_hidden_state(self, batch_size):
        # States are created batch-first so nn.DataParallel can scatter them along
        # dim 0; forward() permutes them back to (layers * directions, batch, hidden).
        h_0 = torch.rand(batch_size, self.number_layer * self.bi_number, self.hidden_size).to(device)
        c_0 = torch.rand(batch_size, self.number_layer * self.bi_number, self.hidden_size).to(device)
        return (h_0, c_0)

    def forward(self, input, hidden):
        input_embeded = self.embedding(input)           # (batch, seq_len, embedding_dim)
        input_embeded = input_embeded.permute(1, 0, 2)  # LSTM expects (seq_len, batch, embedding_dim)
        # Restore the (layers * directions, batch, hidden) layout; nn.LSTM requires
        # the initial states as a tuple, not a list.
        hidden = tuple(x.permute(1, 0, 2).contiguous() for x in hidden)
        _, (h_n, c_n) = self.lstm(input_embeded, hidden)
        # Concatenate the last layer's forward (h_n[-2]) and backward (h_n[-1])
        # final hidden states before the fully connected classifier.
        out = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), -1)
        out = self.fc(out)
        return out
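# train() below relies on several globals that are defined elsewhere in the original
# script: device, batch, classfication_model, optimizer, criterion, train_set,
# test_set, corpus_dataset, and the pretrained word-vector model `model`.
# A minimal setup sketch, assuming an nn.DataParallel wrapper (which is why train()
# reaches through classfication_model.module), an Adam optimizer, and cross-entropy
# loss; the hyperparameters here are illustrative, not the original values.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = 128
classfication_model = nn.DataParallel(Classfication_Model().to(device))
optimizer = torch.optim.Adam(classfication_model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()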
def train(epoch):
    ds = corpus_dataset(train_model=True, max_sentence_length=50, train_set=train_set, test_set=test_set)
    train_dataloader = DataLoader(ds, batch, shuffle=True, num_workers=5)
    total_loss = 0
    classfication_model.train()
    for idx, (input, target) in enumerate(train_dataloader):
        target = target.to(device)
        input = input.to(device)
        optimizer.zero_grad()
        # init_hidden_state lives on the wrapped module because the model is inside
        # nn.DataParallel; len(input) is the actual batch size, which can be smaller
        # than `batch` for the last batch of the epoch.
        hidden = classfication_model.module.init_hidden_state(len(input))
        output = classfication_model(input, hidden)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch:{epoch} ###### total_loss:{total_loss:.6f}")