二十八、基于TextCNN的中文文本分类四

1. 模型的训练和评估

1.1 模型预测的流程

  • 模型预测的流程包括对文本预处理
  • 构建预测数据迭代器
  • 调用模型完成预测

1.2 TextCNN文本分类流程

  1. 准备数据:从THUCNews中抽取了20万条新闻标题,共10个预测类别

  2. 数据预处理:构建词汇表、文本向量化、按批次读取数据

  3. 模型构建:输入层->Embeding层->全连接层->输出层

  4. 模型的训练、评估和预测

1.3 代码实现

  • 步骤一:使用测试数据评估模型predict_eval.py
# coding: UTF-8
# coding:utf-8
import torch
import numpy as np
from unit28.train_eval import evaluate

MAX_VOCAB_SIZE = 10000
UNK, PAD = '', ''

tokenizer = lambda x: [y for y in x]  # char-level

def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path)) # 加载训练好的的模型
    model.eval()  # 开启评价模式

    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)

  • 步骤二:加载待分类数据predict_eval.py
def load_dataset(text, vocab, config, pad_size=32):
    contents = []
    for line in text:
        lin = line.strip()
        if not lin:
            continue
        words_line = []
        token = tokenizer(line)
        seq_len = len(token)
        if pad_size:
            if len(token) < pad_size:
                token.extend([PAD] * (pad_size - len(token)))
            else:
                token = token[:pad_size]
                seq_len = pad_size
        # word to id
        for word in token:
            words_line.append(vocab.get(word, vocab.get(UNK)))
        contents.append((words_line, int(0), seq_len))
    return contents  # [([...], 0), ([...], 1), ...]
  • 步骤三:加载训练好的模型进行预测predict_eval.py
def match_label(pred, config):
    label_list = config.class_list
    return label_list[pred]


def final_predict(config, model, data_iter):
    map_location = lambda storage, loc: storage
    model.load_state_dict(torch.load(config.save_path, map_location=map_location))
    model.eval()
    predict_all = np.array([])
    with torch.no_grad():
        for texts, _ in data_iter:
            outputs = model(texts)
            pred = torch.max(outputs.data, 1)[1].cpu().numpy()
            pred_label = [match_label(i, config) for i in pred]
            predict_all = np.append(predict_all, pred_label)
    return predict_all
  • 步骤四:主函数run.py
# coding:utf-8

from unit28.TextCNN import Config
from unit28.TextCNN import Model
from unit28.load_data import build_dataset
from unit28.load_data_iter import build_iterator
from unit28.predict_eval import test,load_dataset,final_predict

text = ['国考28日网上查报名序号查询后务必牢记报名参加2011年国家公务员的考生,如果您已通过资格审查,那么请于10月28日8:00后,登录考录专题网站查询自己的“关键数字”——报名序号。'
            '国家公务员局等部门提醒:报名序号是报考人员报名确认和下载打印准考证等事项的重要依据和关键字,请务必牢记。此外,由于年龄在35周岁以上、40周岁以下的应届毕业硕士研究生和'
            '博士研究生(非在职),不通过网络进行报名,所以,这类人报名须直接与要报考的招录机关联系,通过电话传真或发送电子邮件等方式报名。',
            '高品质低价格东芝L315双核本3999元作者:徐彬【北京行情】2月20日东芝SatelliteL300(参数图片文章评论)采用14.1英寸WXGA宽屏幕设计,配备了IntelPentiumDual-CoreT2390'
            '双核处理器(1.86GHz主频/1MB二级缓存/533MHz前端总线)、IntelGM965芯片组、1GBDDR2内存、120GB硬盘、DVD刻录光驱和IntelGMAX3100集成显卡。目前,它的经销商报价为3999元。',
            '国安少帅曾两度出山救危局他已托起京师一代才俊新浪体育讯随着联赛中的连续不胜,卫冕冠军北京国安的队员心里到了崩溃的边缘,俱乐部董事会连夜开会做出了更换主教练洪元硕的决定。'
            '而接替洪元硕的,正是上赛季在李章洙下课风波中同样下课的国安俱乐部副总魏克兴。生于1963年的魏克兴球员时代并没有特别辉煌的履历,但也绝对称得上特别:15岁在北京青年队获青年'
            '联赛最佳射手,22岁进入国家队,著名的5-19一战中,他是国家队的替补队员。',
            '汤盈盈撞人心情未平复眼泛泪光拒谈悔意(附图)新浪娱乐讯汤盈盈日前醉驾撞车伤人被捕,',
            '甲醇期货今日挂牌上市继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。基准价均为3050元/吨继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式'
            '挂牌交易。郑州商品交易所(郑商所)昨日公布首批甲醇期货8合约的上市挂牌基准价,均为3050元/吨。据此推算,买卖一手甲醇合约至少需要12200元。业内人士认为,作为国际市场上的'
            '首个甲醇期货品种,其今日挂牌后可能会因炒新资金追捧而出现冲高走势,脉冲式行情过后可能有所回落,不过,投资者在上市初期应关注期现价差异常带来的无风险套利交易机会。',
            '佟丽娅穿白色羽毛长裙美翻,自曝跳舞的女孩能吃苦',
            '江欣燕透露汤盈盈钱嘉乐分手 用冷笑话补救']

if __name__ == "__main__":
    config = Config()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, False)
    # 1. 批量加载测试数据
    test_iter = build_iterator(test_data,config, False)
    config.n_vocab = len(vocab)
    # 2. 加载模型结构
    model = Model(config).to(config.device)
    # 3. 测试
    test(config, model, test_iter)

    print("+++++++++++++++++")

    # 4. 预测

    content = load_dataset(text, vocab, config)
    predict_iter = build_iterator(content, config, predict=True)

    result = final_predict(config, model, predict_iter)
    for i, j in enumerate(result):
        print('text:{}'.format(text[i]), '\t', 'label:{}'.format(j))

1.4 运行结果

运行结果:

D:\Users\tarena\PycharmProjects\nlp\venv\Scripts\python.exe D:/Users/tarena/PycharmProjects/nlp/unit28/run.py
Loading data...
Vocab size: 4762
180000it [00:02, 71001.39it/s]
10000it [00:00, 49459.62it/s]
10000it [00:00, 80214.20it/s]
Test Loss:  0.43,  Test Acc: 86.52%
Precision, Recall and F1-Score...
              precision    recall  f1-score   support

          财经     0.8903    0.8520    0.8707      1000
          房产     0.9414    0.8510    0.8939      1000
          股票     0.8416    0.7650    0.8015      1000
          教育     0.9266    0.9470    0.9367      1000
          科技     0.7047    0.8710    0.7791      1000
          社会     0.8651    0.8660    0.8656      1000
          时政     0.7977    0.8870    0.8400      1000
          体育     0.8968    0.9390    0.9174      1000
          游戏     0.9573    0.8070    0.8757      1000
          娱乐     0.8947    0.8670    0.8807      1000

    accuracy                         0.8652     10000
   macro avg     0.8716    0.8652    0.8661     10000
weighted avg     0.8716    0.8652    0.8661     10000

Confusion Matrix...
[[852   9  58   5  28   9  24  10   1   4]
 [ 21 851  26   8  28  20  19   7   2  18]
 [ 64  20 765   4  71   0  63   7   3   3]
 [  1   0   3 947   8  10  11   7   2  11]
 [  3   6  25   8 871  22  31   6  12  16]
 [  5  12   2  20  25 866  47   6   2  15]
 [  6   1  18  12  34  28 887  10   0   4]
 [  1   0   2   2  17   8  14 939   1  16]
 [  1   1   6   4 123   7  11  25 807  15]
 [  3   4   4  12  31  31   5  30  13 867]]
+++++++++++++++++
text:国考28日网上查报名序号查询后务必牢记报名参加2011年国家公务员的考生,如果您已通过资格审查,那么请于1028800后,登录考录专题网站查询自己的“关键数字”——报名序号。国家公务员局等部门提醒:报名序号是报考人员报名确认和下载打印准考证等事项的重要依据和关键字,请务必牢记。此外,由于年龄在35周岁以上、40周岁以下的应届毕业硕士研究生和博士研究生(非在职),不通过网络进行报名,所以,这类人报名须直接与要报考的招录机关联系,通过电话传真或发送电子邮件等方式报名。 	 label:教育
text:高品质低价格东芝L315双核本3999元作者:徐彬【北京行情】220日东芝SatelliteL300(参数图片文章评论)采用14.1英寸WXGA宽屏幕设计,配备了IntelPentiumDual-CoreT2390双核处理器(1.86GHz主频/1MB二级缓存/533MHz前端总线)、IntelGM965芯片组、1GBDDR2内存、120GB硬盘、DVD刻录光驱和IntelGMAX3100集成显卡。目前,它的经销商报价为3999元。 	 label:科技
text:国安少帅曾两度出山救危局他已托起京师一代才俊新浪体育讯随着联赛中的连续不胜,卫冕冠军北京国安的队员心里到了崩溃的边缘,俱乐部董事会连夜开会做出了更换主教练洪元硕的决定。而接替洪元硕的,正是上赛季在李章洙下课风波中同样下课的国安俱乐部副总魏克兴。生于1963年的魏克兴球员时代并没有特别辉煌的履历,但也绝对称得上特别:15岁在北京青年队获青年联赛最佳射手,22岁进入国家队,著名的5-19一战中,他是国家队的替补队员。 	 label:体育
text:汤盈盈撞人心情未平复眼泛泪光拒谈悔意(附图)新浪娱乐讯汤盈盈日前醉驾撞车伤人被捕, 	 label:娱乐
text:甲醇期货今日挂牌上市继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。基准价均为3050元/吨继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。郑州商品交易所(郑商所)昨日公布首批甲醇期货8合约的上市挂牌基准价,均为3050元/吨。据此推算,买卖一手甲醇合约至少需要12200元。业内人士认为,作为国际市场上的首个甲醇期货品种,其今日挂牌后可能会因炒新资金追捧而出现冲高走势,脉冲式行情过后可能有所回落,不过,投资者在上市初期应关注期现价差异常带来的无风险套利交易机会。 	 label:财经
text:佟丽娅穿白色羽毛长裙美翻,自曝跳舞的女孩能吃苦 	 label:娱乐
text:江欣燕透露汤盈盈钱嘉乐分手 用冷笑话补救 	 label:娱乐

Process finished with exit code 0

1.6 完整代码

"""predict_eval.py"""

# coding:utf-8
import torch
import numpy as np
from unit28.train_eval import evaluate


MAX_VOCAB_SIZE = 10000
UNK, PAD = '', ''

tokenizer = lambda x: [y for y in x]  # char-level


def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path)) # 加载训练好的的模型
    model.eval()  # 开启评价模式

    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)


def load_dataset(text, vocab, config, pad_size=32):
    contents = []
    for line in text:
        lin = line.strip()
        if not lin:
            continue
        words_line = []
        token = tokenizer(line)
        seq_len = len(token)
        if pad_size:
            if len(token) < pad_size:
                token.extend([PAD] * (pad_size - len(token)))
            else:
                token = token[:pad_size]
                seq_len = pad_size
        # word to id
        for word in token:
            words_line.append(vocab.get(word, vocab.get(UNK)))
        contents.append((words_line, int(0), seq_len))
    return contents  # [([...], 0), ([...], 1), ...]


def match_label(pred, config):
    label_list = config.class_list
    return label_list[pred]


def final_predict(config, model, data_iter):
    map_location = lambda storage, loc: storage
    model.load_state_dict(torch.load(config.save_path, map_location=map_location))
    model.eval()
    predict_all = np.array([])
    with torch.no_grad():
        for texts, _ in data_iter:
            outputs = model(texts)
            pred = torch.max(outputs.data, 1)[1].cpu().numpy()
            pred_label = [match_label(i, config) for i in pred]
            predict_all = np.append(predict_all, pred_label)
    return predict_all

"""run.py"""
# coding:utf-8

from unit27.TextCNN import Config
from unit27.TextCNN import Model
from unit27.load_data import build_dataset
from unit27.load_data_iter import build_iterator
from unit27.train_eval import train

if __name__ == "__main__":
    config = Config()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, False)
    # 1. 批量加载数据
    train_iter = build_iterator(train_data, config, False)
    dev_iter = build_iterator(dev_data,config,False)

    config.n_vocab = len(vocab)
    # 2. 构建模型
    model = Model(config).to(config.device)
    print(model.parameters)

    # init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter)
terator(train_data, config, False)
    dev_iter = build_iterator(dev_data,config,False)

    config.n_vocab = len(vocab)
    # 2. 构建模型
    model = Model(config).to(config.device)
    print(model.parameters)

    # init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter)

你可能感兴趣的:(自然语言处理,分类,pytorch)