Xiaohei Chews Through fastNLP: Custom Metrics for Classification

1. Subclass MetricBase: ClassfierMetric

import warnings
import numpy as np
import torch
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from fastNLP.core.metrics import MetricBase
from fastNLP.core.utils import _get_func_signature, seq_len_to_mask

class ClassfierMetric(MetricBase):
    
    def __init__(self, pred=None, target=None, seq_len=None):
        """
        :param pred: key for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
        :param target: key for `target` in the parameter map; None means the mapping is `target` -> `target`
        :param seq_len: key for `seq_len` in the parameter map; None means the mapping is `seq_len` -> `seq_len`
        """
        super().__init__()
        self._init_param_map(pred=pred, target=target, seq_len=seq_len)

        # accumulators across batches
        self.total = 0
        self.total_pred = []
        self.total_target = []
    
    def evaluate(self, pred, target, seq_len=None):
        """
        Accumulate metric statistics over one batch of predictions.

        :param torch.Tensor pred: prediction tensor; shape can be torch.Size([B,]), torch.Size([B, n_classes]),
                torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
        :param torch.Tensor target: ground-truth tensor; the corresponding shapes are torch.Size([B,]),
                torch.Size([B,]), torch.Size([B, max_len]), and torch.Size([B, max_len])
        :param torch.Tensor seq_len: sequence lengths; the corresponding shapes are None, None,
                torch.Size([B]), and torch.Size([B]). Ignored if a mask is passed in as well.
        """
        # pred and target can arrive in several shapes; validate types first
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(target)}.")
        if seq_len is not None and not isinstance(seq_len, torch.Tensor):
            raise TypeError(f"`seq_len` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(seq_len)}.")
        
        # build a mask to exclude padding positions from the element count
        if seq_len is not None and target.dim() > 1:
            max_len = target.size(1)
            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
        else:
            masks = None
        
        if pred.dim() == target.dim():
            pass  # pred already holds label indices
        elif pred.dim() == target.dim() + 1:
            pred = pred.argmax(dim=-1)  # reduce class scores to label indices
            if seq_len is None and target.dim() > 1:
                warnings.warn("You are not passing `seq_len` to exclude pad when calculating accuracy.")
        else:
            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                               f"size:{pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")
        target = target.to(pred)
        if masks is not None:
            self.total += torch.sum(masks).item()
        else:
            self.total += np.prod(list(pred.size()))

        # for sentence classification targets are 1-d, so no mask filtering is needed here
        self.total_pred.extend(pred.cpu().tolist())
        self.total_target.extend(target.cpu().tolist())

    def get_metric(self, reset=True):
        p = precision_score(self.total_target, self.total_pred, average='micro')
        r = recall_score(self.total_target, self.total_pred, average='micro')
        f = f1_score(self.total_target, self.total_pred, average='micro')
        print(classification_report(self.total_target, self.total_pred))
        if reset:  # clear accumulators so the metric can be reused
            self.total = 0
            self.total_pred = []
            self.total_target = []
        return {'P': p, 'R': r, 'F': f}
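
If the names in your model's output dict or your dataset's fields differ from the defaults, the constructor's mapping parameters take care of the renaming. A minimal sketch; the keys `output` and `label` here are hypothetical stand-ins for whatever your model and dataset actually use:

# Hypothetical key names: the model returns scores under `output`,
# and the dataset stores labels in a field called `label`.
metric = ClassfierMetric(pred='output', target='label')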

2. Core code

from fastNLP import DataSet, Instance, Vocabulary
from fastNLP.io import ChnSentiCorpPipe, DataBundle
from tqdm import tqdm
from fastNLP.embeddings import StaticEmbedding, StackEmbedding,BertEmbedding,LSTMCharEmbedding
from fastNLP import Trainer, CrossEntropyLoss,BCELoss
import torch
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from fastNLP.models import BertForSequenceClassification,CNNText

def get_data(file):
    dataset = DataSet()
    with open(file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    for line in tqdm(lines):
        fields = line.split('\t')
        assert len(fields) == 3
        # column 1 is the text, column 2 the label (column 0 is ignored)
        text = fields[1].strip()
        label = fields[2].strip()

        # build a fastNLP Instance: characters as `words`, label string as `target`
        instance = Instance(words=list(text), target=label)
        dataset.append(instance)

    return dataset
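
get_data expects each line to carry three tab-separated columns and only uses the last two (text, then label); whatever sits in the first column (an id, say) is ignored. A purely hypothetical line of train.txt could look like:

3	这家店的服务很不错	12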

def merge_data(train_set, test_set):
    # label vocabulary: no padding/unknown tokens for targets
    target_vocab = Vocabulary(padding=None, unknown=None)
    target_vocab.from_dataset(train_set, test_set, field_name='target')
    target_vocab.index_dataset(train_set, test_set, field_name='target')

    # character vocabulary built over both splits, then applied in place
    char_vocab = Vocabulary()
    char_vocab.from_dataset(train_set, test_set, field_name='words')
    char_vocab.index_dataset(train_set, test_set, field_name='words')

    return train_set, test_set, target_vocab, char_vocab
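
Both index_dataset calls rewrite the fields in place, replacing characters and label strings with integer ids. A quick sanity check (illustrative only) maps them back:

# Illustrative round-trip: the indexed fields now hold ids, not strings.
print(char_vocab.to_word(train_set[0]['words'][0]))      # first character of sample 0
print(target_vocab.to_word(train_set[0]['target']))      # its original label string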


def evaluate(test_set, model):
    # manual alternative to fastNLP's Tester: predict sample by sample
    preds = []
    labels = []
    with torch.no_grad():
        for i in tqdm(range(len(test_set))):
            word_ids = torch.LongTensor(test_set.words[i])
            # CNNText.predict returns {'pred': argmax label ids}
            pred = model.predict(word_ids.view(1, -1))
            preds.append(pred['pred'].numpy()[0])
            labels.append(test_set.target[i])
    p = precision_score(labels, preds, average='micro')
    r = recall_score(labels, preds, average='micro')
    f1 = f1_score(labels, preds, average='micro')
    print(classification_report(labels, preds))
    print('P:', p)
    print('R:', r)
    print('F1:', f1)
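
This helper is a manual alternative to the Tester used below; the script never calls it, but once the model is loaded it could be run as evaluate(test_set, model_CNN).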


train_set = get_data('./classify_data3/train.txt')
test_set = get_data('./classify_data3/test.txt')
train_set, test_set, target_vocab, char_vocab = merge_data(train_set, test_set)


# 100-d fastNLP character embeddings; characters seen fewer than 2 times map to <unk>
fastnlp_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d', min_freq=2)
model_CNN = CNNText(fastnlp_embed, num_classes=27, dropout=0.1)
model_CNN.load_state_dict(torch.load('model_CNN.pth')['net'])

# mark which fields feed the model and which hold the labels
train_set.set_target('target')
train_set.set_input('words')

test_set.set_target('target')
test_set.set_input('words')
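
The checkpoint model_CNN.pth loaded above stores the state dict under the key 'net'; the post does not show the training run that produced it. Below is a minimal sketch of how such a checkpoint could be created with the Trainer, CrossEntropyLoss, and ClassfierMetric imported earlier; every hyperparameter is an assumption, not taken from the post:

# Hypothetical training run; batch size, epoch count, and device are assumptions.
trainer = Trainer(train_data=train_set, model=model_CNN,
                  loss=CrossEntropyLoss(pred='pred', target='target'),
                  metrics=ClassfierMetric(), dev_data=test_set,
                  batch_size=32, n_epochs=10, device='cpu')
trainer.train()
torch.save({'net': model_CNN.state_dict()}, 'model_CNN.pth')  # same layout the load above expects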


from fastNLP import Tester
tester = Tester(data=test_set, model=model_CNN, metrics=ClassfierMetric(), device='cpu')
tester.test()

Output:

607 out of 6280 words have frequency less than 2.
Found 5579 out of 5673 words in the pre-training embedding.
              precision    recall  f1-score   support

       0       0.53      0.78      0.63      4440
       1       0.50      0.48      0.49      3511
       2       0.62      0.57      0.59      3016
       3       0.66      0.57      0.61      2722
       4       0.56      0.56      0.56      2683
       5       0.65      0.63      0.64      1303
       6       0.54      0.62      0.58      1268
       7       0.64      0.76      0.70      1199
       8       0.87      0.60      0.71      1086
       9       0.88      0.67      0.76      1018
      10       0.58      0.43      0.49       896
      11       0.58      0.43      0.50       850
      12       0.70      0.61      0.65       496
      13       0.52      0.57      0.54       460
      14       0.48      0.08      0.14       365
      15       0.79      0.75      0.77       361
      16       0.55      0.55      0.55       352
      17       0.61      0.53      0.57       322
      18       0.69      0.44      0.54       295
      19       0.81      0.75      0.78       262
      20       0.64      0.35      0.45       133
      21       0.85      0.29      0.43        76
      22       0.85      0.62      0.72        37
      23       0.00      0.00      0.00        32
      24       0.20      0.08      0.11        25

    accuracy                           0.59     27208
   macro avg       0.61      0.51      0.54     27208
weighted avg       0.60      0.59      0.59     27208

Evaluate data in 5.79 seconds!
[tester]
ClassfierMetric: P=0.5936489267862394, R=0.5936489267862394, F=0.5936489267862394
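
Note that P, R, and F come out identical: with average='micro' on a single-label classification task, micro precision, recall, and F1 all reduce to overall accuracy (the 0.59 in the report above).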
