Flat Lattice 代码

NLP项目实践——中文序列标注Flat Lattice代码解读、运行与使用_yangjie_word_char_mix.txt_常鸿宇的博客-CSDN博客

论文阅读《FLAT:Chinese NER Using Flat-Lattice Transformer》_flat论文_LawsonAbs的博客-CSDN博客 Flat-Lattice-Transformer模型源码测试_ontonote4ner_Dongxue_NLP的博客-CSDN博客

运行 python flat_main.py --dataset clue

1.根据已有代码,添加了

elif args.dataset == 'clue':
    datasets, vocabs, embeddings = load_clue2020(clue_2020_ner_path, yangjie_rich_pretrain_unigram_path,
                                                             yangjie_rich_pretrain_bigram_path,
                                                             _refresh=refresh_data, index_token=False,
                                                             train_clip=args.train_clip,
                                                             _cache_fp=raw_dataset_cache_name,
                                                             char_min_freq=args.char_min_freq,
                                                             bigram_min_freq=args.bigram_min_freq,
                                                             only_train_min_freq=args.only_train_min_freq
                                                             )

2.gbk问题就写encoding=‘utf-8’

3.invalid argument:因为cacahe文件夹里要读取的一个文件名里有冒号,所以在flat_main.py下

raw_dataset_cache_name = os.path.join('cache', args.dataset +
                                      '_trainClip_{} '.format(args.train_clip)
                                      + 'bgminfreq_{} '.format(args.bigram_min_freq)
                                      + 'char_min_freq_{} '.format(args.char_min_freq)
                                      + 'word_min_freq_{} '.format(args.word_min_freq)
                                      + 'oy_train_m_fq{} '.format(args.only_train_min_freq)
                                      + 'number_norm{} '.format(args.number_normalized)
                                      + 'load_dataset_seed{} '.format(load_dataset_seed))

中间略过

cache_name = os.path.join('cache', (args.dataset + '_lattice' + '_only_train_{} ' +
                                    '_trainClip_{} ' + '_norm_num:{} '
                                    + 'char_min_freq{} ' + 'bigram_min_freq{} ' + 'word_min_freq{} ' + 'oy_train_m_fq{}'
                                    + 'number_norm{} ' + 'lexicon_{} ' + 'load_dataset_seed_{} ')
                          .format(args.only_lexicon_in_train,
                                  args.train_clip, args.number_normalized, args.char_min_freq,
                                  args.bigram_min_freq, args.word_min_freq, args.only_train_min_freq,
                                  args.number_normalized, args.lexicon_name, load_dataset_seed))
print('cache_name 名字鹿丸')

这里把文件命名改了,冒号全去掉

4.若出现enconding type的错误,注意看自己用的什么数据集,比如clue,就改成bio。

NER 原理及 TENER 模型搭建_tener模型_xiedelong的博客-CSDN博客

5.flat_main.py 因为默认设置就是ues bert == 1,所以这里把

Flat Lattice 代码_第1张图片

 fastNLP_module_V1 会引用一个BertModel,会出错,之后再看

最后clue数据集的结果

Flat Lattice 代码_第2张图片

注意:这里test的f1,P,R为0是因为test测试集有问题,全标的o

注意:很有可能每个clue的标注方式是不一样的,有的是bio,有的是bmeso,如果先读取的是bio的,并且存储到了cacheli,你想读取bmeso的clue数据集时,需要删掉cache里的数据,不然就算代码里改成bmeso也没用

2.MSRA数据集

直接参数设置为msra就行

1.原代码的load_msra_without_dev没有这个函数,用load_msra_ner_1,注意其中的train_path改成train_dev.char.bmes。不然会报错

 
  

2.encodingtype改成bioes

3.因为运行msra数据集时,有些句子长度大于512,所以设置了

auto_truncate=True

 4项目内的msra数据集有问题,句子过长,参考最上面的博客重新下载,有从bio转成bmeso的代码。

3.RESUME数据集

1.encodingtype是 ‘bmeso’

这是在服务器上跑出来的结果

Flat Lattice 代码_第3张图片

也是服务器上跑,按论文上的最佳参数,基本上和论文结果相同了,还要好一点

python flat_main.py --dataset resume --epoch 30 --batch 10 --weight_decay 0.05 --head 12 --warmup 5 

 Flat Lattice 代码_第4张图片

4.weibo数据集

在网上下载weibo数据集后,要将原始数据中的seg信息去掉。

建立一个py文件,运行即可(注意use_bert默认设置为TRUE,要和论文中的bert+flat相比较

#!usr/bin/env python
# encoding: utf-8

from paths import weibo_ner_path
import os

def deseg_weibo(weibopath):
    train_path = os.path.join(weibopath, 'weiboNER_2nd_conll.train')
    dev_path = os.path.join(weibopath, 'weiboNER_2nd_conll.dev')
    test_path = os.path.join(weibopath, 'weiboNER_2nd_conll.test')

    for data_file in [train_path, dev_path, test_path]:
        output_file = data_file + "_deseg"
        f_out = open(output_file, "w", encoding='utf8')
        with open(data_file, "r", encoding='utf8') as f:
            for line in f.readlines():
                line = line.strip()
                if line != "":
                    span_list = line.split('\t')
                    raw_char = ''.join(list(span_list[0])[:-1])
                    tag = span_list[-1]
                    f_out.write(' '.join([raw_char, tag]) + '\n')
                else:
                    f_out.write('\n')

if __name__ == '__main__':
    deseg_weibo(weibo_ner_path)
    print('- Done!')

结果差了论文结果一点,应该是超参数调优的问题

Flat Lattice 代码_第5张图片

按照论文上的超参数,效果比论文里还是差一点,可能是因为head 换成12了

--dataset
weibo
--epoch
30
--batch
4
--lr
1e-3
--weight_decay
0.05
--warmup
5
--head
12 Flat Lattice 代码_第6张图片

把head改回8,效果,原因再探究

Flat Lattice 代码_第7张图片

 batch_size换到8跑不动,要在服务器上跑

05.数据输出流程代码(以weibo为例)

(1)从trainer.train()进入,然后self._train()

(2)然后到trainer.py中的_do_validation()函数,用来得到验证和测试数据          

(3) 进入res = self.tester.test(),这里面计算了各个评价指标,时间还有最佳效果的判定

    def test(self):
        r"""开始进行验证,并返回验证结果。

        :return Dict[Dict]: dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。
        """
        # turn on the testing mode; clean up the history
        self._model_device = _get_model_device(self._model)
        network = self._model
        self._mode(network, is_test=True)
        data_iterator = self.data_iterator
        eval_results = {}
        try:
            with torch.no_grad():
                if not self.use_tqdm:
                    from .utils import _pseudo_tqdm as inner_tqdm
                else:
                    inner_tqdm = tqdm
                with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
                    pbar.set_description_str(desc="Test")

                    start_time = time.time()

                    for batch_x, batch_y in data_iterator:
                        _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                        pred_dict = self._data_forward(self._predict_func, batch_x)
                        if not isinstance(pred_dict, dict):
                            raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
                                            f"must be `dict`, got {type(pred_dict)}.")
                        for metric in self.metrics:
                            metric(pred_dict, batch_y)

                        if self.use_tqdm:
                            pbar.update()

                    for metric in self.metrics:
                        eval_result = metric.get_metric() #得到计算结果
                        if not isinstance(eval_result, dict):
                            raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
                                            f"`dict`, got {type(eval_result)}")
                        metric_name = metric.get_metric_name() #评价指标名字
                        eval_results[metric_name] = eval_result #写入列表
                    pbar.close()
                    end_time = time.time()
                    test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
                    if self.verbose >= 0:
                        self.logger.info(test_str)
        except _CheckError as e:
            prev_func_signature = _get_func_signature(self._predict_func)
            _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                 check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                 dataset=self.data, check_level=0)
        
        if self.verbose >= 1:
            logger.info("[tester] \n{}".format(self._format_eval_results(eval_results)))
        self._mode(network, is_test=False) #传入参数,是否是test模式
        return eval_results

 然后返回到_do_validation()函数

注意:这里先得到的是dev的结果

on_valid_end()函数计算并输出了test的结果
    def _do_validation(self, epoch, step):
        self.callback_manager.on_valid_begin()
        res = self.tester.test() #评价指标的计算

        #判断是否是最佳
        is_better_eval = False
        if self._better_eval_result(res):
            if self.save_path is not None:
                self._save_model(self.model,
                                 "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
            elif self._load_best_model:
                self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict().items()}
            self.best_dev_perf = res
            self.best_dev_epoch = epoch
            self.best_dev_step = step
            is_better_eval = True
        # get validation results; adjust optimizer
        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) #计算得到test数据并输出
        return res

on_valid_end()函数

    def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
        if better_result:
            eval_result = deepcopy(eval_result)
            eval_result['step'] = self.step
            eval_result['epoch'] = self.epoch
            fitlog.add_best_metric(eval_result)
        fitlog.add_metric(eval_result, step=self.step, epoch=self.epoch) #这里应该是把dev的结果传到metric里去了
        if len(self.testers) > 0:
            for key, tester in self.testers.items():
                try:
                    eval_result = tester.test() #得到test的数据
                    if self.verbose != 0:
                        self.pbar.write("FitlogCallback evaluation on {}:".format(key))
                        self.pbar.write(tester._format_eval_results(eval_result))
                    fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch)
                    if better_result:
                        fitlog.add_best_metric(eval_result, name=key)
                except Exception:
                    self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
    

key先为data-test,tester调用test()函数得到test的评价指标值

然后输出到屏幕上,返回到train()函数

最后wrapper()函数取出dev数据,再输出dev的评价指标值

你可能感兴趣的:(python学习,python,开发语言)