pytorch_CRF应用——huggingface的transformers模块Trainer中使用CRF

transformers的Trainer中使用CRF

  • 0. 关于CRF
  • 1. 下载一个pytorch实现的crf模块
  • 2. torchcrf的基本使用方法
  • 3. 对transformers模块进行修改
  • 4. 对torchcrf模块进行修改
  • 5. 关于评估

0. 关于CRF

条件随机场(CRF)是序列标注任务中常用的模型,其基本作用是给定一个序列的特征,对序列中每一个节点的状态进行预测,既可以单独用于序列标注任务,也可以在bert等编码器的基础上,将编码特征作为输入,可以有效地提高序列标注模型的准确性。

bert4keras中,crf层的调用非常方便,只需要一行代码即可实现,而基于pytorch的最常用的NLP模块transformers的样例中,却并没有给出CRF的使用方法。在transformers提供的Trainer中,也没有预留可以加入CRF的地方,使得用户可以方便的操作(huggingface太追求面面俱到了,原本是为了让使用者利用非常简短的代码就可以实现功能,但是在过度封装的情况下,任何一点小的改动都变得比较困难)。

关于transformers做序列标注,可以直接参考官方文档。

经过了一些实践之后,认为其提供的trainer并不适合开发者使用,更适合不需要修改代码的情况下使用,而如果有自己的模型需要搭建,我个人认为使用trainer并不是明智的选择,包括其中的各种闭环和控制操作都需要对整个模块有一定的理解,才能够熟练的使用。即便如此,我还是决定写这篇博客,希望能给有需要的同学提供一些参考。

1. 下载一个pytorch实现的crf模块

pypi上关于crf的实现模块有很多,在这里推荐Allennlp实现的版本,pytorch-crf,版本是0.7.2,其代码比较清晰,注释也很多,读起来比较容易。关于其他版本的crf实现,我没有做尝试。

下载地址:
https://pypi.org/project/pytorch-crf/#files

联网状态下可直接利用pip安装:

pip install pytorch-crf

引用方法:

from torchcrf import CRF

2. torchcrf的基本使用方法

这个模块的使用非常简便,只需要在两个位置添加。

第一步,在构造方法中创建CRF对象。需要传入一个参数num_labels,也就是标签总数(计入BIO的)。

class model(torch.nn.Model):
	def __init__(self, your_args_here, num_labels):
		# Your code here.
		self.crf = CRF(num_tags=config.num_labels, batch_first=True)

第二步,在forward方法中,计算crf损失。

	def forward(self, input_ids, attention_mask, token_type_ids, labels):
		# Your code here.
		loss = self.crf(logits, labels)
		loss = -1 * loss
		return loss

注意在这个地方loss要乘以-1,起初我在修改这部分的时候发现损失始终得不到优化,检查crf的代码才发现,计算的两项logsumexp是反着的。

3. 对transformers模块进行修改

熟悉了crf的基本使用方法,我们就可以按照自己的需求取修改transformers模块了。在序列标注任务中,通常使用的默认模型是BertForTokenClassification,所以干脆在这里边加入crf。同时我们希望保留该模型原有的功能,那么就加入一个参数use_crf,用来控制是否使用crf模块。

直接定位到源码位置,开始修改。脚本位于transformers下的models/bert下,名为modeling_bert.py。可以参考下面的路劲找到这个脚本。

vim /root/anaconda3/envs/your_env/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py

打开之后,先要在头部位置加上引用:

from torchcrf import CRF

直接搜BertForTokenClassification,找到这个类。修改**init**方法:

# 增加了一个参数use_crf来控制是否使用crf
# 如果使用,则创建crf对象

	def __init__(self, config, use_crf=False):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        self.use_crf = use_crf
        if self.use_crf:
            self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

接下来修改它的forward方法:

# 从sequence_output = self.dropout(sequence_output)这一行开始,前面的不用改
# 其实就只是根据self.use_crf的值增加了一个分支
# 如果使用了crf则将损失替换为crf的损失,仅此而已。

		sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        loss_fct = CrossEntropyLoss()
        if labels is not None:
            # print(logits.shape)
            # print(labels.shape)
            if self.use_crf:
                # 如果使用crf,mask放在crf内部计算
                # print(logits.shape)
                # print(labels.shape)
                if attention_mask is not None:
                    loss = self.crf(logits, labels, mask=attention_mask.byte())
                    loss = -1 * loss
                    # print(loss)
                else:
                    loss = self.crf(logits, labels) 
                    loss = -1 * loss
            else:
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = logits.view(-1, self.num_labels)
                    active_labels = torch.where(
                        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                    )
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

# 后面的return也还是原样

在这一部分代码中,注意到非crf的分支对attention_mask进行了操作的,而crf的分支中却并没有这一步操作,这是因为torchcrf的代码本身有对attention_mask进行判断的操作,所以无需再这里再算一次。

现在transformers中的模型代码已经修改好了,可以通过以下的方法创建模型:

from transformers import BertForTokenClassification

model = BertForTokenClassification.from_pretrained('bert-base-chinese',
													num_labels=len(label_list),
													use_crf=True)

但是如果直接把这个model传入trainer进行训练的话,会报CUDA的错误,把device设置成cpu就会发现,其实这个错误是’list out of range’。

这是因为transformers的Trainer类会默认采用-99进行padding,这个值传入crf中之后,crf会按照label去进行索引,-99自然索引不到,所以报了list out of range。

4. 对torchcrf模块进行修改

针对3中最后提到的报错问题,我们打开crf的脚本进行简单的修改。

vim /root/anaconda3/envs/your_env/lib/python3.6/site-packages/torchcrf/__init__.py

找到def _compute_score方法,在如下的位置修改:

 # 只添加了一行tags = torch.tensor(np.maximum(tags.cpu().numpy(), 0))
 # 道理很简单,就是不让-99溢出,把它替换成0即可
 # 至于padding的mask问题不用担心,因为我们传了mask进来
 # 代码中有其他地方对mask进行了判断的
 
 	def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()
        # seq_ends = mask.long().sum(dim=0) - 1
        # 导致报错的原因是tags里边padding使用-100
        tags = torch.tensor(np.maximum(tags.cpu().numpy(), 0)) 

# 剩下的部分维持原样

经过这些修改,训练过程就可以正常进行了。

5. 关于评估

训练的问题解决了,那么评估呢?固然可以不进行修改,依然使用argmax的方法进行解码,但是这样一来crf就白用了,而且以CRF的损失进行优化,以argmax进行解码的话,导致两个步骤的不一致性,会令模型的效果变差。

在Trainer中,评估用到的方法叫prediction_loop,这个方法写的比较复杂,把所有batch和step中的所有结果全都堆叠在一起了,然后再重新组织(好像是这个样子,记不太清了),要修改它的话需要很多reshape的操作,比较麻烦,所以我也就没有再继续进行。

如果执意想用Trainer的话,不妨思考一下这个地方具体该怎么修改,或者在callback的控制中,控制其不进行评估,只训练,然后另写评估和预测方法。

考虑到修改源码的时间成本,我没有再继续用trainer进行训练,而是自己另写了一套训练和评估的方法。(其实通过阅读大量的论文和项目可以发现,尽管transformers模块被广泛的使用,但几乎没有人是直接采用了它的Trainer进行训练的)

我把另写的评估方法也贴出来,供看到这里的同学参考。

def evaluate(model, valid_dataloader, id2label, device):
    """
    对模型进行评估
    ---------------
    ver: 2021-08-30
    by: changhongyu
    """
    pred_true = 1e-10
    pred_all = 1e-10
    true_all = 1e-10
    
    for n, (input_ids, token_type_ids, attention_mask, true_labels) in enumerate(tqdm(valid_dataloader)):
        input_ids = input_ids.squeeze(0).to(device)
        token_type_ids = token_type_ids.squeeze(0).to(device)
        attention_mask = attention_mask.squeeze(0).to(device)
        true_labels = true_labels.squeeze(0).to(device)
        
        with torch.no_grad():
            eval_loss = model(input_ids, token_type_ids, attention_mask, true_labels)
            logits = model(input_ids, token_type_ids, attention_mask)  # (b, l, num_labels)
            # pred_labels = torch.argmax(logits, dim=-1)   # (b, l)
            pred_labels = model.crf.decode(logits)  # 主要看这一行
            
        # pred_labels = pred_labels.detach().cpu().numpy().tolist()[0]   # (l)
        pred_labels = pred_labels.detach().cpu().numpy().tolist()[0][0]   # (l)  # 还有这一行
        true_labels = true_labels.detach().cpu().numpy().tolist()
            
		# 这个方法是从labels获取entities的方法
		# 不同的数据形式对应不一样的结构
        gold_ents = get_entities_from_labels(true_labels, id2label)  
        pred_ents = get_entities_from_labels(pred_labels, id2label)
#         if n % 30 == 1:
#             print(gold_coarse)
#             print(pred_coarse)
#             print('---------')
        
        pred_all += len(pred_ents)
        true_all += len(gold_ents)
        
        for ent in pred_ents:
            if ent in gold_ents:
                pred_true += 1
                
        precision = pred_true / pred_all
        recall = pred_true / true_all
        f1 = 2 * precision * recall / (precision + recall)
        
    return f1, precision, recall

总结一下,CRF是序列标注任务中非常基础,也是非常重要的方法,即便现在的很多模型中,NER任务不再按照序列标注的形式来进行,但CRF的重要性仍然是不可忽视的,阅读一下源码对理解CRF的运作过程,也会有很大的帮助。transformers模块作为一个便捷的工具,采用trainer可以是初学者很快上手,但我认为并不适合开发者在其基础上进行大量的修改,除非已经对transformers的整个流程和逻辑已经习惯了且非常熟练,如果源码都没有从头到尾读一遍的话,还是自己设计自己的流程吧,不然遇到问题了再去debug会很麻烦。

如果这篇文章对你有帮助的话,麻烦点个赞投个币支持一下。我们下期再见。

你可能感兴趣的:(自然语言处理,pytorch,深度学习,自然语言处理)