基于TensorFlow框架的CNN文本小说分类

项目需求说明:

  • 要求实现小说类别的多分类,给定一段文字,识别出是哪一种小说 

  • 项目需要使用TensorFlow的CNN算法  

  • 能够根据相应的数据集训练自己的模型并给出测试精度

项目环境配置:

  • windows10/windows7
  • python3.6.8(X64)
  • numpy=1.16.4
  • tensorflow=1.9.0
  • sklearn=0.20.3

数据集搜集标注:

有关小说的数据集非常难得,需要花钱从网上买,或者自己搜集标注。分享一下该项目对应的数据集:

基于TensorFlow框架的CNN文本小说分类_第1张图片

模型训练过程展示:

基于TensorFlow框架的CNN文本小说分类_第2张图片

模型测试:

基于TensorFlow框架的CNN文本小说分类_第3张图片

说明:如果训练的迭代次数增多的话,效果有可能会更好的。 

模型应用:

基于TensorFlow框架的CNN文本小说分类_第4张图片

输入一段小说内容,预测出该内容属于的小说类别是“言情” 。

结果中的警告不影响程序运行及结果,使用合适的python第三方库版本可消除警告。

预测核心代码展示:

# coding: utf-8

from __future__ import print_function

import os
import tensorflow as tf
import tensorflow.contrib.keras as kr
import warnings
warnings.filterwarnings("ignore")
from cnn_model import TCNNConfig, TextCNN
from data.cnews_loader import read_category, read_vocab

try:
    bool(type(unicode))
except NameError:
    unicode = str

base_dir = 'data/novel'
vocab_dir = os.path.join(base_dir, 'novel.vocab.txt')

save_dir = 'checkpoints/textcnn'
save_path = os.path.join(save_dir, 'best_validation')  # 最佳验证结果保存路径


class CnnModel:
    def __init__(self):
        self.config = TCNNConfig()
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.config.vocab_size = len(self.words)
        self.model = TextCNN(self.config)

        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=self.session, save_path=save_path)  # 读取保存的模型

    def predict(self, message):
        # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
        content = unicode(message)
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]

        feed_dict = {
            self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
            self.model.keep_prob: 1.0
        }

        y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
        return self.categories[y_pred_cls[0]]


if __name__ == '__main__':
    cnn_model = CnnModel()
    test_demo = ['情侣二人相约公园观看月色']
    for i in test_demo:
        print('该小说的类别是:',cnn_model.predict(i))
 

备注:

1.对项目感兴趣或有问题可以私信我

2.可实现类似的相关文本分类

 

你可能感兴趣的:(项目,机器学习,TensorFlow,CNN,小说分类,文本分类,python)