Below is Bi-LSTM+CRF code that uses Baidu's ERNIE-1.0 for feature extraction:
```python
import torch
import torch.nn as nn
from torchcrf import CRF          # from the pytorch-crf package
from transformers import AutoModel


class ERNIE_LSTM_CRF(nn.Module):
    """
    ernie_lstm_crf model
    """
    def __init__(self, ernie_config, tagset_size, embedding_dim, hidden_dim,
                 rnn_layers, dropout_ratio, dropout1, use_cuda=False):
        super(ERNIE_LSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        # Load ERNIE as the feature extractor
        self.word_embeds = AutoModel.from_pretrained(ernie_config)
        # self.word_embeds = ErnieModel.from_pretrained(ernie_config, from_hf_hub=False)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            num_layers=rnn_layers, bidirectional=True,
                            dropout=dropout_ratio, batch_first=True)
        self.rnn_layers = rnn_layers
        self.dropout1 = nn.Dropout(p=dropout1)
        self.crf = CRF(num_tags=tagset_size, batch_first=True)
        self.linear = nn.Linear(hidden_dim * 2, tagset_size)
        self.tagset_size = tagset_size

    def rand_init_hidden(self, batch_size):
        """
        Randomly initialize the LSTM hidden state (h_0, c_0).
        """
        return (torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim),
                torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim))

    def forward(self, sentence, attention_mask=None):
        """
        args:
            sentence (batch_size, word_seq_len): word-level representation of sentence
            attention_mask (batch_size, word_seq_len): mask marking real tokens vs. padding
        return:
            crf input (batch_size, word_seq_len, tag_size)
        """
        batch_size = sentence.size(0)
        seq_length = sentence.size(1)
        embeds = self.word_embeds(sentence, attention_mask=attention_mask)
        hidden = self.rand_init_hidden(batch_size)
        if embeds[0].is_cuda:
            hidden = tuple(i.cuda() for i in hidden)
        # embeds[0] is ERNIE's last_hidden_state: (batch_size, seq_len, embedding_dim)
        lstm_out, hidden = self.lstm(embeds[0], hidden)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim * 2)
        d_lstm_out = self.dropout1(lstm_out)
        l_out = self.linear(d_lstm_out)
        lstm_feats = l_out.contiguous().view(batch_size, seq_length, -1)
        return lstm_feats

    def loss(self, feats, mask, tags):
        """
        feats: size=(batch_size, seq_len, tag_size)
        mask:  size=(batch_size, seq_len)
        tags:  size=(batch_size, seq_len)
        :return: mean negative log-likelihood over the batch
        """
        loss_value = -self.crf(feats, tags, mask)  # NLL of the gold tag sequences
        batch_size = feats.size(0)
        loss_value /= float(batch_size)
        return loss_value
```
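For orientation, here is a minimal usage sketch. The checkpoint name `nghuyong/ernie-1.0-base-zh`, the hyperparameter values, and the toy batch are illustrative assumptions, not taken from the original code; it presumes `CRF` comes from the `pytorch-crf` package and the encoder from Hugging Face `transformers`.

```python
import torch
from transformers import AutoTokenizer

# Assumed checkpoint name for ERNIE-1.0; all hyperparameters below are illustrative.
ckpt = "nghuyong/ernie-1.0-base-zh"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = ERNIE_LSTM_CRF(
    ernie_config=ckpt,       # forwarded to AutoModel.from_pretrained
    tagset_size=7,           # e.g. BIO tags for PER/LOC/ORG plus O (assumption)
    embedding_dim=768,       # must match ERNIE-1.0 base hidden size
    hidden_dim=256,
    rnn_layers=1,            # note: nn.LSTM dropout only applies when rnn_layers > 1
    dropout_ratio=0.5,
    dropout1=0.5,
)

batch = tokenizer(["百度总部位于北京。"], return_tensors="pt", padding=True)
tags = torch.zeros_like(batch["input_ids"])     # dummy gold tags, for illustration only
mask = batch["attention_mask"].bool()

feats = model(batch["input_ids"], attention_mask=batch["attention_mask"])
loss = model.loss(feats, mask, tags)            # CRF negative log-likelihood
best_path = model.crf.decode(feats, mask)       # Viterbi decoding -> list of tag-id lists
loss.backward()
```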
The CRF is used in two places in Bi-LSTM+CRF: decoding at inference time, and computing the loss during training.

```python
# Viterbi decoding is used here for inference
best_path = model.crf.decode(feats, masks.byte())
# loss computation
loss = model.loss(feats, masks.byte(), tags)
# which internally evaluates
loss_value = -self.crf(feats, tags, mask)
```
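A side note on the output format: `crf.decode` in `pytorch-crf` returns the best tag-id path per sequence, with padded positions already stripped. The `id2label` mapping below is an invented example for illustration, not from the original code.

```python
id2label = {0: "O", 1: "B-PER", 2: "I-PER"}   # illustrative mapping, not from the post
pred_labels = [[id2label[t] for t in path] for path in best_path]
```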
The CRF is used in the loss computation in order to exploit sequence-level information and improve prediction accuracy. A bidirectional LSTM (BiLSTM) captures contextual features on both sides of a given token, but on its own it does not model the constraints between adjacent labels. The CRF layer accounts for the dependencies between labels in a sequence, so the predicted tag sequence is globally optimized and well-formed, as the sketch below illustrates.
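Here is a small, self-contained demonstration of that point (tag names and scores are invented; only the `pytorch-crf` API is assumed): per-token argmax can emit an illegal transition such as O → I-PER, while Viterbi decoding takes the transition scores into account.

```python
import torch
from torchcrf import CRF

# Toy setup: tag 0 = O, tag 1 = B-PER, tag 2 = I-PER (invented for illustration).
crf = CRF(num_tags=3, batch_first=True)
with torch.no_grad():
    crf.start_transitions.zero_()
    crf.end_transitions.zero_()
    crf.transitions.zero_()
    crf.transitions[0, 2] = -10.0   # strongly penalize the illegal transition O -> I-PER

emissions = torch.tensor([[[2.0, 0.0, 0.1],     # token 1: argmax picks O
                           [0.0, 1.0, 1.2]]])   # token 2: argmax picks I-PER

print(emissions.argmax(-1).tolist())  # [[0, 2]] -> O, I-PER: illegal sequence
print(crf.decode(emissions))          # [[0, 1]] -> O, B-PER: respects transitions
```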
The CRF layer's loss is typically the negative log-likelihood of the correct tag sequence given the input sequence, i.e. the log-sum-exp over the scores of all possible tag paths minus the score of the gold path. By minimizing this loss, the model learns to raise the probability of the correct tag sequence while lowering the probability of incorrect ones. Passing a mask lets the model handle variable-length sequences by ignoring padding positions, which is essential for mini-batch training.
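That masking behaviour can be checked directly; the sketch below uses arbitrary tensors against the `pytorch-crf` API and shows that corrupting the emissions at a masked-out padding position leaves the loss unchanged.

```python
import torch
from torchcrf import CRF

crf = CRF(num_tags=3, batch_first=True)
emissions = torch.randn(1, 4, 3)                       # one sequence of length 4
tags = torch.tensor([[1, 2, 0, 0]])
mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)  # last position is padding

loss_a = -crf(emissions, tags, mask)
corrupted = emissions.clone()
corrupted[0, 3] = 999.0                                # change only the padded position
loss_b = -crf(corrupted, tags, mask)
print(torch.allclose(loss_a, loss_b))                  # True: the padded step is ignored
```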