text_cnn:文本分类实战(二)参数设置和构建网络

文章目录

  • CNN
  • 实战

CNN

用的是tf.layers.conv1d
模型模块用来搭建神经网络。

实战

之前在一章完成了数据处理部分,可以将文本转换成等大小的图片格式的矩阵,标签可以转换成数字后再one-hot显示

class TCNNconfig(object)# CNN 配置参数
	n_dim=64   #嵌入维度  也就是以词为单位的维度(大白话:词变成长度为64的数字向量)
	seg_length = 300   #  有300个词  矩阵就是300*64  大小可调    
	num_classes = 10   #  y  10类
	num_fliters = 256  # 卷积核数目 
	kernel_size =  5   # 卷积核尺寸  因为使用conv1d 说白了就是以大小为5的窗口在x的每一行上进行卷积
    
    hidden_dim = 128   #  全连接层神经元
	drop_keep_prob = 0.5   #  dropout保留比例
	learning_rate = 1e-3  # 学习率

	batch_size = 64
	num_epochs = 10 

	print_per_batch = 100  # 每100个batch 输出一次结果
	save_per_batch = 10  # 每多少epochs 存入 tensorboard
·
	...
	

设置了一个类用存放配置参数,方便下面的调用。
接下来要搭建网络模型。

class TCNN(object):
	
	def __init__(self,config):
		self.config = config
	# 参数设置得传进来才对
		
		self.input_x = tf.placeholder(tf.float32,[None, self.config.seq_length, self.config.n_dim], name='input_x')
		self.input_y = tf.placeholder(tf.float32,[None, self.config.num_class], name='input_y')
		self.keep_drop_out = tf.placeholder(tf.float32, name='keep_prob')
		
		self.cnn()    #调用
	def cnn(self):  #  默认在GPU环境运行,(已安装)
		# 注意此时文本还没有进行向量化表示
		with tf.name_scope('cnn'):
			conv = tf.layers.conv1d(self.input_x, self.config.num_fliters, self.config.kernel_size, name='conv')
			gmp = tf.reduce_max(conv, reduction_indices = [1], name = 'gmp')
			
		with tf.name_scope('score'):
			fc = tf.layers.dense(gmp, self.config.hidden_dim, name = 'fc1')	
			fc = tf.contrib.layers.dropout(fc, self.keep_prob)
			fc = tf.nn.relu(fc)
		
			self.logits = tf.layers.dense(fc, self.config.hidden_dim, name='fc2')
			self.y_pred = tf.argmax(tf.nn.softmax(self.logits), 1)
			
		with tf.name_scope('optimize'):
			#  用来优化
			cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits = self.logits, labels = self.input_y)
			self.loss = tf.reduce_mean(cross_entropy)
			self.optimizer = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

		with tf.name_scope('acc'):
			
			correct_pred = tf.equal(tf.argmax(self.input_y), self.y_pred)
			self.acc = tf.reduce_mean(tf.cast(correct_pred,tf.float32))




		...

当网络定义好,接下来就是如何调用前两个模块,完整代码会在最后一个模块给出链接。

你可能感兴趣的:(text_cnn:文本分类实战(二)参数设置和构建网络)