[Study Notes] PyTorch LSTM/RNN Code

'''
# RNN and LSTM are defined in almost the same way
# for an LSTM you can choose whether or not to pass h_0 and c_0 along with the input

rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
'''
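
For reference on the shapes above: with the default batch_first=False, the input is [seq_len, batch, input_size] and the states are [num_layers * num_directions, batch, hidden_size]; if (h0, c0) is omitted, both default to zeros. A minimal sketch to check this:

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(5, 3, 10)      # [seq_len=5, batch=3, input_size=10]
output, (hn, cn) = rnn(x)      # no (h0, c0) given -> both initialized to zeros
print(output.shape)            # torch.Size([5, 3, 20])  per-step outputs of the last layer
print(hn.shape, cn.shape)      # torch.Size([2, 3, 20])  final h/c for each layer
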
# Below is an example of using an LSTM for classification.
# `model` (a pre-trained gensim word-vector model) and `device` are assumed to be defined elsewhere.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class Classfication_Model(nn.Module):
    def __init__(self):
        super(Classfication_Model, self).__init__()
        self.hidden_size = 128
        self.embedding_dim = 200
        self.number_layer = 4
        self.bidirectional = True
        self.bi_number = 2 if self.bidirectional else 1
        self.dropout = 0.5
        # vocabulary size comes from the pre-trained word vectors, with extra slots (+200) for additional tokens
        self.embedding = nn.Embedding(num_embeddings=len(model.index_to_key) + 200
                                       , embedding_dim=self.embedding_dim)

        self.lstm = nn.LSTM(input_size=self.embedding_dim
                            , hidden_size=self.hidden_size
                            , num_layers=self.number_layer
                            , dropout=self.dropout
                            , bidirectional=self.bidirectional)
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_size*self.bi_number,20)
            , nn.ReLU()
            , nn.Linear(20,2)
        )

    def init_hidden_state(self, batch_size):
        # The states are built batch-first here and permuted to
        # [num_layers * num_directions, batch_size, hidden_size] in forward(),
        # so nn.DataParallel can split them along the batch dimension (dim 0).
        h_0 = torch.rand(batch_size, self.number_layer * self.bi_number, self.hidden_size).to(device)
        c_0 = torch.rand(batch_size, self.number_layer * self.bi_number, self.hidden_size).to(device)
        return (h_0, c_0)

    def forward(self, input, hidden):
        input_embeded = self.embedding(input)
        input_embeded = input_embeded.permute(1, 0, 2)  # rearrange to [seq_len, batch_size, embedding_dim]
        h_0, c_0 = (x.permute(1, 0, 2).contiguous() for x in hidden)  # back to [num_layers * num_directions, batch_size, hidden_size]
        _, (h_n, c_n) = self.lstm(input_embeded, (h_0, c_0))
        out = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), -1)  # concat last layer's forward/backward states -> [batch_size, hidden_size * 2]
        out = self.fc(out)
        return out
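
Why index h_n[-2] and h_n[-1] in forward()? For a bidirectional LSTM, h_n is ordered layer by layer with the two directions adjacent, so the last two slices are the final layer's forward and backward states. A small standalone check with the same sizes as the model above (random inputs, no embedding layer, just to confirm the shapes):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=200, hidden_size=128, num_layers=4, bidirectional=True)
x = torch.randn(50, 2, 200)           # [seq_len=50, batch=2, embedding_dim=200]
_, (h_n, c_n) = lstm(x)
print(h_n.shape)                      # torch.Size([8, 2, 128]) -> 4 layers * 2 directions
out = torch.cat((h_n[-2], h_n[-1]), dim=-1)
print(out.shape)                      # torch.Size([2, 256]) -> matches the fc input size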
 
def train(epoch):
    ds = corpus_dataset(train_model=True, max_sentence_length=50, train_set=train_set, test_set=test_set)
    train_dataloader = DataLoader(ds, batch, shuffle=True, num_workers=5)
    total_loss = 0
    classfication_model.train()
    # hidden = classfication_model.init_hidden_state(batch)          # breaks under DataParallel
    # hidden = classfication_model.module.init_hidden_state(batch)   # works, but the batch_size is hard-coded
    for idx, (input, target) in enumerate(train_dataloader):
        target = target.to(device)
        input = input.to(device)
        optimizer.zero_grad()
        # initialize h_0 and c_0
        # this is done for every batch
        hidden = classfication_model.module.init_hidden_state(len(input))  # batch_size follows the actual batch
        output = classfication_model(input, hidden)
        loss = criterion(output, target)  # target labels must be 0-indexed, e.g. [0, 9] rather than [1, 10]
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch:{epoch}  ######  total_loss:{total_loss:.6f}")
