首先,一定要吐槽一下,千万不要拿笔记本跑训练集很大的模型,不然真的能跑到吐血三升,而且电脑卡到宕机之后好不容易等到电脑能动了就是报错“内存不足”,简直能逼死我。每次跑程序都得把虚存开到最大,然后看着我的电脑在那卡卡卡,我都在想会不会把我电脑跑废了。当然结局是美好的,因为模型跑出来的审核结果正确率在94%以上,简直是惊喜到不行,嗯,基于这个理由,给我配工作站让我专门跑算法了,开心~
好了,言归正传,由于THUCTC模型在文章审核方面的正确率很低(可能是切词结果和词的权重不合适导致的),我需要去寻找新的、更合适的模型和算法。最后选择了gensim(一个主要面向自然语言处理的Python库,里面也包括tf-idf)来构建和训练word2vec词向量(word2vec最初由谷歌提出),神经网络选用了CNN卷积神经网络,分词选用了jieba分词,并按照词性筛选需要的关键词来构建词向量。比较核心的就是这几个了。还有一些对文档的处理,看到代码时就知道了,主要是筛除一些干扰性的无用词和标点,并将文章统一成相同的长度等。
话不多说,让我们来看看代码吧。
首先是训练文档的准备。以什么形式存储并不重要,只要后续能够读取并作为模型的输入就行,我这里选用的是TensorFlow的二进制记录格式TFRecord(文件后缀为.tfrecords)。如果用的是普通文本,自己读写文件就好了,这里就以我用的TFRecord的读写来说明。
代码如下,代码中也包括对于预测文档的准备及粗处理:
class loadData:
    """Database access and text preprocessing helpers for the article classifier.

    Reads MySQL settings from an INI-style config file, exports labelled rows
    to a TFRecord file for training, fetches rows to classify, and writes
    predictions back.  The text helpers mirror the module-level ones used on
    the training side so both sides preprocess text identically.
    """

    # Table and text column queried for each ``kind`` (0: bidding, 1: article).
    _TABLE_NAME = {0: "bidding", 1: "article"}
    _DESCRIPTION = {0: "content", 1: "details"}
    # Matches any "<...>" tag; used to pre-strip markup in createTrainFile.
    _ANY_TAG_RE = re.compile(r'<[^>]*>')

    # =========================================================================
    # Config file access
    # =========================================================================
    def __init__(self, base_dir='.', path="database.conf"):
        """Load the config file at *path*.

        Args:
            base_dir: directory in which createTrainFile writes its TFRecords.
            path: INI-style config file; also the default used by connDB.
        """
        self.base_dir = base_dir
        self.path = path
        self.cf = configparser.ConfigParser()
        self.cf.read(self.path)

    def get_configInfo(self, field, key):
        """Return option *key* of section *field*, or "" when it is missing."""
        try:
            return self.cf.get(field, key)
        except configparser.Error:
            return ""

    def read_config(self, config_file_path="database.conf", field="db"):
        """Return (host, user, password, db name, charset, port) from *field*.

        Raises:
            configparser.Error: if the section or an option is missing.  The
                error is printed for diagnosis and then re-raised so callers
                see the real cause instead of an unbound-variable error.
        """
        cf = configparser.ConfigParser()
        keys = ("db_host", "db_user", "db_pass", "db_name", "db_charset", "db_port")
        try:
            cf.read(config_file_path)
            return tuple(cf.get(field, key) for key in keys)
        except Exception as inst:
            traceback.print_exc()
            print(type(inst))
            print(inst.args)
            print(inst)
            raise

    # =========================================================================
    # Database operations
    # =========================================================================
    def connDB(self, config_file_path=None):
        """Open a MySQL connection and a dictionary cursor.

        Args:
            config_file_path: config file to read; defaults to the path passed
                to the constructor.
        Returns:
            (conn, cur) where *cur* is a ``pymysql.cursors.DictCursor``.
        """
        if config_file_path is None:
            config_file_path = self.path
        db_host, db_user, db_pass, db_name, db_charset, db_port = self.read_config(config_file_path)
        conn = pymysql.connect(host=db_host, user=db_user, password=db_pass, db=db_name,
                               charset=db_charset, port=int(db_port))
        cur = conn.cursor(cursor=pymysql.cursors.DictCursor)
        return conn, cur

    def exeUpdate(self, conn, cur, sql):
        """Execute an UPDATE/INSERT/DELETE, commit, and return execute()'s result."""
        sta = cur.execute(sql)
        conn.commit()
        return sta

    def exeQuery(self, cur, sql):
        """Execute a SELECT and return all fetched rows."""
        cur.execute(sql)
        return cur.fetchall()

    def connClose(self, conn, cur):
        """Close the cursor and the connection."""
        cur.close()
        conn.close()

    # =========================================================================
    # Export rows to TFRecords for training / fetch rows for prediction
    # =========================================================================
    def createTestFile(self, kind, date=0):
        """Fetch rows of table *kind* with state 1 dated *date* days ago.

        Each row's title and text column are cleaned with clean_str and
        reduced to keywords with seperate_line.

        Returns:
            (content_data, content_id): keyword strings and the matching row
            ids, in the same order (newest id first).
        """
        table = self._TABLE_NAME[kind]
        column = self._DESCRIPTION[kind]
        state = 1
        pagesize = 100
        daytime = int(time.time()) - date * 3600 * 24
        content_data = []
        content_id = []
        conn, cur = self.connDB()
        try:
            mark = 0
            while True:
                sql = ("select * from " + table + " where state = " + str(state)
                       + " and FROM_UNIXTIME(time,'%Y-%m-%d') = FROM_UNIXTIME(" + str(daytime)
                       + ",'%Y-%m-%d') order by id desc limit " + str(mark * pagesize)
                       + " , " + str(pagesize))
                print(sql)
                rows = self.exeQuery(cur, sql)
                if not rows:
                    break
                for data in rows:
                    # NULL columns come back as None; treat them as empty text.
                    text = (data["title"] or "") + "\t" + (data[column] or "")
                    # Same denoising as the training side.
                    content_data.append(self.seperate_line(self.clean_str(text)))
                    content_id.append(data["id"])
                mark += 1
                print('已经处理完第 %d 页' % (mark))
            print("待分类数据生成成功!")
        finally:
            self.connClose(conn, cur)
        return content_data, content_id

    def clean_str(self, string):
        """Remove all whitespace, then replace runs of ASCII letters, digits and
        the listed punctuation with a single space; strip the ends."""
        string = re.sub(r'\s+', "", string)
        r1 = u'[A-Za-z0-9’!"#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
        string = re.sub(r1, ' ', string)
        return string.strip()

    def seperate_line(self, line):
        """Tokenise the first half of *line* with jieba and keep keywords.

        Words tagged 'nr' or 'ns' are dropped; other tags starting with 'n',
        and the tag 'v', are kept.  Returns the kept words each followed by a
        single space.
        """
        half = line[:len(line) // 2]
        kept = [word for word, flag in pseg.cut(half)
                if flag and flag not in ('nr', 'ns') and (flag[0] == 'n' or flag == 'v')]
        return ''.join(word + " " for word in kept)

    def batch_iter(self, data, batch_size, epoch_num, shuffle=True):
        """Yield consecutive slices of *data* of at most *batch_size* items.

        Runs *epoch_num* passes; each pass is re-shuffled when *shuffle* is
        true.  The last batch of a pass may be shorter; no batch is empty.
        """
        data = np.array(data)
        data_size = len(data)
        # Ceiling division (the old `data_size - 1 / batch_size` produced
        # roughly data_size batches, almost all of them empty).
        batch_num_per_epoch = (data_size + batch_size - 1) // batch_size
        for _ in range(epoch_num):
            if shuffle:
                shuffled_data = data[np.random.permutation(np.arange(data_size))]
            else:
                shuffled_data = data
            for batch_num in range(batch_num_per_epoch):
                start_idx = batch_num * batch_size
                end_idx = min(start_idx + batch_size, data_size)
                yield shuffled_data[start_idx:end_idx]

    def padding_sentences(self, input_sentences, padding_token, padding_sentence_length=None):
        """Split each sentence on whitespace and pad/truncate it to one length.

        Args:
            input_sentences: iterable of space-separated strings.
            padding_token: token appended to short sentences.
            padding_sentence_length: target length; defaults to the longest
                sentence (0 for empty input).
        Returns:
            (sentences, max_sentence_length), sentences being token lists.
        """
        sentences = [sentence.split() for sentence in input_sentences]
        if padding_sentence_length is not None:
            max_sentence_length = padding_sentence_length
        else:
            max_sentence_length = max((len(s) for s in sentences), default=0)
        for sentence in sentences:
            if len(sentence) > max_sentence_length:
                # Truncate in place; rebinding the loop name would not change `sentences`.
                del sentence[max_sentence_length:]
            else:
                sentence.extend([padding_token] * (max_sentence_length - len(sentence)))
        return (sentences, max_sentence_length)

    def saveDict(self, input_dict, output_file):
        """Pickle *input_dict* into *output_file*."""
        with open(output_file, 'wb') as f:
            pickle.dump(input_dict, f)

    def loadDict(self, dict_file):
        """Unpickle and return the object in *dict_file*.

        Only use on files this project wrote: unpickling can run arbitrary code.
        """
        with open(dict_file, 'rb') as f:
            return pickle.load(f)

    def check_padding_sentences(self, input_sentences, x_raw):
        """Keep the sentences whose length equals that of the first one.

        *x_raw* is filtered positionally in parallel, so each kept sentence
        stays paired with its own raw entry (also for duplicate sentences).
        Returns ([], []) for empty input.
        """
        if len(input_sentences) == 0:
            return ([], [])
        valid_length = len(input_sentences[0])
        print('checking padding sentences..., valid length : ', valid_length)
        valid_sentences = []
        new_x_raw = []
        for sentence, raw in zip(input_sentences, x_raw):
            if len(sentence) == valid_length:
                valid_sentences.append(sentence)
                new_x_raw.append(raw)
        return (valid_sentences, new_x_raw)

    def createTrainFile(self, kind, filename):
        """Write every state-1 row of table *kind* to a TFRecord file.

        Each record has an int64 ``label`` (``article_class_id`` with commas
        removed) and a UTF-8 bytes ``content`` (title + tab + text with tags,
        newlines and spaces stripped).  The file goes to base_dir/filename;
        the writer and the DB connection are closed even on error.
        """
        table = self._TABLE_NAME[kind]
        column = self._DESCRIPTION[kind]
        state = 1
        pagesize = 100
        conn, cur = self.connDB()
        writer = tf.python_io.TFRecordWriter(os.path.join(self.base_dir, filename))
        try:
            mark = 0
            while True:
                sql = ("select * from " + table + " where state = " + str(state)
                       + " limit " + str(mark * pagesize) + " , " + str(pagesize))
                rows = self.exeQuery(cur, sql)
                if not rows:
                    break
                for data in rows:
                    text = (data["title"] or "") + "\t" + (data[column] or "")
                    stripped = self._ANY_TAG_RE.sub('', text).replace('\n', '').replace(' ', '')
                    content = self.filter_tags(stripped)
                    label = int(str(data["article_class_id"]).replace(',', ''))
                    example = tf.train.Example(features=tf.train.Features(feature={
                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                        'content': tf.train.Feature(bytes_list=tf.train.BytesList(value=[content.encode()])),
                    }))
                    writer.write(example.SerializeToString())
                mark += 1
                print('已经处理完第 %d 页' % (mark))
        finally:
            writer.close()
            self.connClose(conn, cur)
        print("训练文件生成成功!")

    def insertPredict(self, content_id, all_pre):
        """Store predicted class ids for the given article ids.

        Inserts one row per (id, class) pair into ``map`` and updates
        ``article.article_class_id`` with a single CASE statement.  Values
        are coerced to int before being placed in the SQL text.  Empty input
        is a no-op.  Returns 0.
        """
        pairs = [(int(cid), int(pre)) for cid, pre in zip(content_id, all_pre)]
        if not pairs:
            return 0
        conn, cur = self.connDB()
        try:
            sqlu = "update article set article_class_id = CASE id"
            for cid, pre in pairs:
                sqlu += " WHEN " + str(cid) + " THEN " + str(pre)
                sql = (" INSERT INTO map (id, article_class_id, article_id, type, class_sort) values (0,"
                       + str(pre) + ", " + str(cid) + ", 1, 0) ")
                print(sql)
                self.exeUpdate(conn, cur, sql)
            # Only ids that got a WHEN clause, otherwise CASE would set NULL.
            sqlu += " END WHERE id IN(" + ",".join(str(cid) for cid, _ in pairs) + ")"
            self.exeUpdate(conn, cur, sqlu)
        finally:
            self.connClose(conn, cur)
        return 0

    def filter_tags(self, htmlstr):
        """Strip CDATA, <script>/<style> blocks, tags and comments from HTML.

        ``<br>`` variants become newlines and runs of newlines are collapsed.
        """
        re_cdata = re.compile(r'//<!\[CDATA\[[^>]*//\]\]>', re.I)
        re_script = re.compile(r'<\s*script[^>]*>[^<]*<\s*/\s*script\s*>', re.I)
        re_style = re.compile(r'<\s*style[^>]*>[^<]*<\s*/\s*style\s*>', re.I)
        re_br = re.compile(r'<br\s*?/?>')
        re_h = re.compile(r'</?\w+[^>]*>')
        re_comment = re.compile(r'<!--[^>]*-->')
        blank_line = re.compile(r'\n+')
        s = re_cdata.sub('', htmlstr)
        s = re_script.sub('', s)
        s = re_style.sub('', s)
        s = re_br.sub('\n', s)
        s = re_h.sub('', s)
        s = re_comment.sub('', s)
        s = blank_line.sub('\n', s)
        return s

    def is_chinese(self, uchar):
        """Return True if *uchar* lies in the CJK range U+4E00..U+9FA5."""
        return u'\u4e00' <= uchar <= u'\u9fa5'

    def format_str(self, content):
        """Return only the Chinese characters of *content* (bytes are decoded as UTF-8)."""
        if isinstance(content, bytes):
            content = content.decode('utf-8')
        return ''.join(ch for ch in content if self.is_chinese(ch))
下面对于训练文档中的数据进行细化的处理,用以提供模型训练的输入:
# encoding=utf8
import re
import jieba.posseg as pseg
import tensorflow as tf
import os
import numpy as np
import codecs
import pickle
# Input text data and its labels.

tf.flags.DEFINE_integer("class_num", 2, "Number of article class to store (default: 2)")
FLAGS = tf.flags.FLAGS
# Flags are deliberately not parsed here: this module is imported by the
# training script before that script defines its own flags, so parsing
# sys.argv now would reject them.  Until something parses the command line,
# the values printed below are the defaults.
print("\nParameters:")
# flag_values_dict() yields plain values (FLAGS.__flags holds Flag objects).
for attr, value in sorted(FLAGS.flag_values_dict().items()):
    print("{}={}".format(attr.upper(), value))
print("")
def load_data_and_label(file_name):
    """Load labelled articles from a TFRecord file and build one-hot labels.

    Each record must carry an int64 ``label`` feature and a UTF-8 bytes
    ``content`` feature.  The content is cleaned with ``clean_str`` and
    reduced to keywords with ``seperate_line``.

    Articles are grouped by label.  Labels get one-hot positions in the order
    they are first seen; ``labelToVec[i]`` is the original label at position i.

    Returns:
        [x_train, label_data, labelToVec]: keyword strings grouped by label,
        an int64 array of shape (len(x_train), FLAGS.class_num) whose rows are
        the one-hot labels aligned with x_train, and the position->label list.

    Raises:
        ValueError: if the file holds more distinct labels than FLAGS.class_num.
    """
    class_num = FLAGS.class_num
    labelToVec = []
    # One bucket of documents per one-hot position.
    content_data = [[] for _ in range(class_num)]
    for string_record in tf.python_io.tf_record_iterator(path=file_name):
        example = tf.train.Example()
        example.ParseFromString(string_record)
        label = example.features.feature['label'].int64_list.value[0]
        raw = example.features.feature['content'].bytes_list.value[0]
        content = seperate_line(clean_str(raw.decode('utf-8')))
        if label not in labelToVec:
            if len(labelToVec) >= class_num:
                raise ValueError(
                    "found more than %d distinct labels in %s; increase --class_num"
                    % (class_num, file_name))
            labelToVec.append(label)
        content_data[labelToVec.index(label)].append(content)

    x_train = [content for bucket in content_data for content in bucket]
    # Built directly so empty buckets (classes with no documents) are harmless.
    label_data = np.zeros((len(x_train), class_num), dtype=np.int64)
    row = 0
    for idx, bucket in enumerate(content_data):
        label_data[row:row + len(bucket), idx] = 1
        row += len(bucket)
    return [x_train, label_data, labelToVec]
def clean_str(string):
    """Drop all whitespace, then turn runs of Latin letters, digits and
    punctuation into single spaces and trim the result.

    Because whitespace is removed first, the only spaces left in the output
    are the separators inserted where such a run used to be.
    """
    punctuation = u'[A-Za-z0-9’!"#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
    without_spaces = re.sub(r'\s+', "", string)
    return re.sub(punctuation, ' ', without_spaces).strip()
def seperate_line(line):
    """Tokenise the first half of *line* with jieba and keep keywords.

    Only the first half of the text is segmented.  Words tagged as person
    names ('nr') or place names ('ns') are dropped; other tags starting with
    'n', and the plain verb tag 'v', are kept.

    Returns:
        The kept words, each followed by a single space.
    """
    half = line[:len(line) // 2]
    new_line = []
    for word, flag in pseg.cut(half):  # jieba segmentation with POS tags
        # Keep keywords by part of speech.
        if flag in ('nr', 'ns') or not flag:
            continue
        if flag[0] != 'n' and flag != 'v':
            continue
        new_line.append(word)
    return ''.join(word + " " for word in new_line)
def batch_iter(data, batch_size, epoch_num, shuffle=True):
    """Yield consecutive slices of *data* of at most *batch_size* items.

    Runs *epoch_num* passes over the data; each pass is re-shuffled when
    *shuffle* is true.  The last batch of a pass may be shorter, and no batch
    is ever empty.
    """
    data = np.array(data)
    data_size = len(data)
    # Ceiling division.  The former `int((data_size - 1 / batch_size)) + 1`
    # divided before subtracting and produced ~data_size batches, mostly empty.
    batch_num_per_epoch = (data_size + batch_size - 1) // batch_size
    for _ in range(epoch_num):
        if shuffle:
            shuffled_data = data[np.random.permutation(np.arange(data_size))]
        else:
            shuffled_data = data
        for batch_num in range(batch_num_per_epoch):
            start_idx = batch_num * batch_size
            end_idx = min(start_idx + batch_size, data_size)
            yield shuffled_data[start_idx:end_idx]
def padding_sentences(input_sentences, padding_token, padding_sentence_length=None):
    """Split sentences on whitespace and bring them all to the same length.

    Args:
        input_sentences: iterable of space-separated strings.
        padding_token: token appended to sentences that are too short.
        padding_sentence_length: target length; defaults to the length of the
            longest sentence (0 for empty input).
    Returns:
        (sentences, max_sentence_length) where sentences are token lists of
        exactly max_sentence_length tokens.
    """
    sentences = [sentence.split() for sentence in input_sentences]
    if padding_sentence_length is not None:
        max_sentence_length = padding_sentence_length
    else:
        max_sentence_length = max((len(s) for s in sentences), default=0)
    for sentence in sentences:
        if len(sentence) > max_sentence_length:
            # Truncate in place: rebinding the loop variable would leave the
            # list inside `sentences` untouched.
            del sentence[max_sentence_length:]
        else:
            sentence.extend([padding_token] * (max_sentence_length - len(sentence)))
    return (sentences, max_sentence_length)
def saveDict(input_dict, output_file):
    """Serialise *input_dict* with pickle into *output_file* (overwritten, binary)."""
    with open(output_file, mode='wb') as handle:
        pickle.dump(input_dict, handle)
def loadDict(dict_file):
    """Return the object pickled in *dict_file*.

    Only load files this project wrote itself: unpickling can execute code.
    """
    with open(dict_file, mode='rb') as handle:
        return pickle.load(handle)
构建CNN神经网络:
import tensorflow as tf
import numpy as np
class TextCNN(object):
    '''
    A CNN for text classification.
    Takes pre-computed word vectors, applies one convolution + max-pooling
    branch per filter size, then dropout and a softmax output layer.
    '''

    def __init__(
            self, sequence_length, num_classes,
            embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):
        """Build the graph.

        Args:
            sequence_length: number of tokens per (padded) document.
            num_classes: width of the one-hot label vectors.
            embedding_size: dimensionality of each word vector.
            filter_sizes: convolution window heights, in tokens.
            num_filters: number of filters per window height.
            l2_reg_lambda: weight of the L2 penalty on the output layer.
        """
        # Placeholders for input, output, dropout
        self.input_x = tf.placeholder(tf.float32, [None, sequence_length, embedding_size], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y")
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")

        # Keeping track of l2 regularization loss (optional)
        l2_loss = tf.constant(0.0)

        # No embedding lookup: input_x already holds the word vectors.
        # [batch, sequence_length, embedding_size] -> add a channel axis of size 1.
        self.embedded_chars = self.input_x
        self.embedded_chars_expended = tf.expand_dims(self.embedded_chars, -1)

        # Create a convolution + maxpool layer for each filter size
        pooled_outputs = []
        for filter_size in filter_sizes:
            with tf.name_scope("conv-maxpool-%s" % filter_size):
                # Each filter spans the full embedding width.
                filter_shape = [filter_size, embedding_size, 1, num_filters]
                W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
                b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
                conv = tf.nn.conv2d(
                    self.embedded_chars_expended,
                    W,
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="conv")
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                # Max over all valid positions -> one value per filter.
                pooled = tf.nn.max_pool(
                    h,
                    ksize=[1, sequence_length - filter_size + 1, 1, 1],
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="pool")
                pooled_outputs.append(pooled)

        # Combine all the pooled features
        num_filters_total = num_filters * len(filter_sizes)
        self.h_pool = tf.concat(pooled_outputs, 3)
        self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])

        # Add dropout
        with tf.name_scope("dropout"):
            self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)

        # Final (unnormalized) scores and predictions
        with tf.name_scope("output"):
            W = tf.get_variable(
                "W",
                shape=[num_filters_total, num_classes],
                initializer=tf.contrib.layers.xavier_initializer())
            # name belongs to the Variable (it used to be passed to tf.constant,
            # leaving the variable with an auto-generated name).
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
            l2_loss += tf.nn.l2_loss(W)
            l2_loss += tf.nn.l2_loss(b)
            self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")
            self.predictions = tf.argmax(self.scores, 1, name="predictions")

        # Calculate mean cross-entropy loss
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

        # Accuracy
        with tf.name_scope("accuracy"):
            correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
开始用该网络训练模型,并保存合适的模型以便预测使用:
#! /usr/bin/env python
# encoding: utf-8
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helper
import word2vec_helpers
from text_cnn import TextCNN
from scipy.sparse import csr_matrix
import random
import sys

# Parameters
# =======================================================

# Data loading parameters
tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")
tf.flags.DEFINE_integer("num_labels", 2, "Number of labels for data. (default: 2)")

# Model hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)")
tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-spearated filter sizes (default: '3,4,5')")
tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")

# Training parameters
tf.flags.DEFINE_integer("batch_size", 10, "Batch Size (default: 10)")
tf.flags.DEFINE_integer("num_epochs", 200, "Number of training epochs (default: 200)")
tf.flags.DEFINE_integer("evaluate_every", 100, "Evalue model on dev set after this many steps (default: 100)")
tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps (defult: 100)")
tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store (default: 5)")

# Misc parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

# Parse parameters from the command line.  All flags (including those of
# data_helper) are defined by now, so parsing here is safe and makes the
# printed values reflect command-line overrides.
FLAGS = tf.flags.FLAGS
if not FLAGS.is_parsed():
    FLAGS(sys.argv)
print("\nParameters:")
for attr, value in sorted(FLAGS.flag_values_dict().items()):
    print("{}={}".format(attr.upper(), value))
print("")

# Prepare output directory for models and summaries
# =======================================================
timestamp = str(int(time.time()))
out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
print("Writing to {}\n".format(out_dir))
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

# Data preprocess
# =======================================================

# Load data
print("Loading data...")
x_text, y, labelToVec = data_helper.load_data_and_label('train.tfrecords')
print("x_text length: ", len(x_text))
print("y length: ", y.shape)
if len(x_text) == 0:
    raise SystemExit("No training examples found in train.tfrecords")

# Randomly reorder the data WITHOUT replacement.  Sampling with replacement
# duplicates some articles and drops others, and the duplicates can end up
# in both the train and dev splits, inflating the dev accuracy.
order = random.sample(range(len(x_text)), len(x_text))
new_x_text = [x_text[i] for i in order]
new_y = [y[i] for i in order]
print("new_x_text length: %d" % len(new_x_text))
print("new_y length: %d" % len(new_y))

# embedding vector
print("Padding sentences...")
sentences, max_document_length = data_helper.padding_sentences(new_x_text, '')
print("embedding_sentences...")
all_vectors = word2vec_helpers.embedding_sentences(
    sentences, embedding_size=FLAGS.embedding_dim,
    file_to_save=os.path.join(out_dir, 'trained_word2vec.model'))
print("all_vectors length %d * %d * %d : " % (len(all_vectors), len(all_vectors[0]), len(all_vectors[0][0])))
# NOTE: materialising every document as one dense array can exhaust memory on
# large corpora.  TODO: stream or sparsify the vectors.
x = np.asarray(all_vectors)
y = np.asarray(new_y)
print("x.shape = {}".format(x.shape))
print("y.shape = {}".format(y.shape))

# Save params needed at prediction time (embedding_dim lets the predictor
# embed with the same width as training).
training_params_file = os.path.join(out_dir, 'training_params.pickle')
params = {'num_labels': FLAGS.num_labels, 'max_document_length': max_document_length,
          'labelToVec': labelToVec, 'embedding_dim': FLAGS.embedding_dim}
data_helper.saveDict(params, training_params_file)

# Shuffle data randomly
np.random.seed(10)
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
y_shuffled = y[shuffle_indices]

# Split train/dev set.  Slicing by the train size (not by a negative dev
# index) keeps the training set intact when the dev size rounds down to 0.
# TODO: This is very crude, should use cross-validation
dev_size = int(FLAGS.dev_sample_percentage * float(len(y)))
train_size = len(y) - dev_size
x_train, x_dev = x_shuffled[:train_size], x_shuffled[train_size:]
y_train, y_dev = y_shuffled[:train_size], y_shuffled[train_size:]
has_dev = len(y_dev) > 0
print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))

# Training
# =======================================================
# Best dev accuracy/loss seen so far; initialised so the first evaluation saves.
saveacc = -1.0
saveloss = float("inf")
with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        cnn = TextCNN(
            sequence_length=x_train.shape[1],
            num_classes=y_train.shape[1],
            embedding_size=FLAGS.embedding_dim,
            filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
            num_filters=FLAGS.num_filters,
            l2_reg_lambda=FLAGS.l2_reg_lambda)

        # Define Training procedure
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars = optimizer.compute_gradients(cnn.loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Keep track of gradient values and sparsity (optional)
        grad_summaries = []
        for g, v in grads_and_vars:
            if g is not None:
                grad_summaries.append(tf.summary.histogram("{}/grad/hist".format(v.name), g))
                grad_summaries.append(tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)))
        grad_summaries_merged = tf.summary.merge(grad_summaries)

        # Output directory for models and summaries
        print("Writing to {}\n".format(out_dir))

        # Summaries for loss and accuracy
        loss_summary = tf.summary.scalar("loss", cnn.loss)
        acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

        # Train Summaries
        train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Dev summaries
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        def train_step(x_batch, y_batch):
            """Run one optimisation step on a batch and log its loss/accuracy."""
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
            }
            _, step, summaries, loss, accuracy = sess.run(
                [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
            train_summary_writer.add_summary(summaries, step)

        def dev_step(x_batch, y_batch, writer=None):
            """Evaluate on the dev set; checkpoint when it is the best so far.

            "Best" means higher dev accuracy, or equal accuracy with lower loss.
            Selecting on the dev set (rather than on a single training batch)
            avoids keeping a model that merely fits a few training samples.
            Because a checkpoint is only written on improvement, the latest
            checkpoint is also the best one.
            """
            global saveacc, saveloss
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: 1.0,
            }
            step, summaries, loss, accuracy = sess.run(
                [global_step, dev_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
            if writer:
                writer.add_summary(summaries, step)
            if saveacc < accuracy or (saveacc == accuracy and saveloss > loss):
                saveacc = accuracy
                saveloss = loss
                path = saver.save(sess, checkpoint_prefix, global_step=step)
                print("Saved model checkpoint to {}\n".format(path))

        def save_checkpoint(step):
            """Unconditionally save a checkpoint (used when there is no dev set)."""
            path = saver.save(sess, checkpoint_prefix, global_step=step)
            print("Saved model checkpoint to {}\n".format(path))

        # Batch over indices so x/y rows are gathered by fancy indexing instead
        # of packing (matrix, vector) pairs into one ragged numpy array.
        batches = data_helper.batch_iter(np.arange(len(y_train)), FLAGS.batch_size, FLAGS.num_epochs)
        # Training loop. For each batch...
        for batch_indices in batches:
            if len(batch_indices) == 0:
                continue
            train_step(x_train[batch_indices], y_train[batch_indices])
            current_step = tf.train.global_step(sess, global_step)
            if has_dev and current_step % FLAGS.evaluate_every == 0:
                print("\nEvaluation:")
                dev_step(x_dev, y_dev, writer=dev_summary_writer)
                print("")
            elif not has_dev and current_step % FLAGS.checkpoint_every == 0:
                save_checkpoint(current_step)

        # Make sure even short runs leave a checkpoint for prediction.
        final_step = tf.train.global_step(sess, global_step)
        if has_dev and final_step % FLAGS.evaluate_every != 0:
            print("\nFinal evaluation:")
            dev_step(x_dev, y_dev, writer=dev_summary_writer)
            print("")
        elif not has_dev and final_step % FLAGS.checkpoint_every != 0:
            save_checkpoint(final_step)
最后当然就是众望所归的预测啦~
#! /usr/bin/env python
import tensorflow as tf
import numpy as np
import os
import word2vec_helpers
import loadData as loadData
import csv
import sys

# Parameters
# ==================================================

# Eval Parameters
tf.flags.DEFINE_integer("num_labels", 12, "Number of labels for data. (default: 12)")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
tf.flags.DEFINE_string("checkpoint_dir", ".\\article_model\\checkpoints", "Checkpoint directory from training run")
tf.flags.DEFINE_string("base_dir", ".", "files base_dir")
tf.flags.DEFINE_string("config_file", "database.conf", "database config file eg:'.conf'")

# model hyperparameters
# NOTE: embedding_dim is not used for embedding; the width recorded in the
# training params (or 128, the training default) is used instead.
tf.flags.DEFINE_integer('embedding_dim', 80, 'dimensionality of characters')
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_integer("date", 0, "date distance(default: 0)")

# Parse now so the values printed below reflect command-line overrides.
FLAGS = tf.flags.FLAGS
if not FLAGS.is_parsed():
    FLAGS(sys.argv)
print("\nParameters:")
for attr, value in sorted(FLAGS.flag_values_dict().items()):
    print("{}={}".format(attr.upper(), value))
print("")

# validate
# ==================================================
# Every required artefact must exist; stop with a non-zero status otherwise.

# validate checkpoint file
checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if checkpoint_file is None:
    raise SystemExit("Cannot find a valid checkpoint file!")
print("Using checkpoint file : {}".format(checkpoint_file))

# validate word2vec model file
trained_word2vec_model_file = os.path.join(FLAGS.checkpoint_dir, "..", "trained_word2vec.model")
if not os.path.exists(trained_word2vec_model_file):
    raise SystemExit("Word2vec model file \'{}\' doesn't exist!".format(trained_word2vec_model_file))
print("Using word2vec model file : {}".format(trained_word2vec_model_file))

# validate training params file
training_params_file = os.path.join(FLAGS.checkpoint_dir, "..", "training_params.pickle")
if not os.path.exists(training_params_file):
    raise SystemExit("Training params file \'{}\' is missing!".format(training_params_file))
print("Using training params file : {}".format(training_params_file))

# Load params
data = loadData.loadData(FLAGS.base_dir, FLAGS.config_file)
params = data.loadDict(training_params_file)
num_labels = int(params['num_labels'])
max_document_length = int(params['max_document_length'])
labelToVec = list(params['labelToVec'])
# Older params files do not record the width; 128 is the training default.
embedding_dim = int(params.get('embedding_dim', 128))

# Load data
x_raw, content_id = data.createTestFile(1, FLAGS.date)
if len(x_raw) == 0:
    print("No articles to classify.")
    raise SystemExit(0)

# Get Embedding vector x_test
print('Padding sentence...')
sentences, max_document_length = data.padding_sentences(x_raw, '', padding_sentence_length=max_document_length)
print('sentences length : %d , max_document_length : %d' % (len(sentences), max_document_length))
# Keep only sentences of exactly the model's input length, carrying the raw
# text and the row id along so all three lists stay aligned.
kept = [(sentence, raw, cid)
        for sentence, raw, cid in zip(sentences, x_raw, content_id)
        if len(sentence) == max_document_length]
if len(kept) < len(sentences):
    print('dropped %d sentences with unexpected length' % (len(sentences) - len(kept)))
if not kept:
    print("No articles left to classify after padding.")
    raise SystemExit(0)
sentences = [item[0] for item in kept]
new_x_raw = [item[1] for item in kept]
new_content_id = [item[2] for item in kept]

all_vectors = word2vec_helpers.embedding_sentences(
    sentences, embedding_size=embedding_dim, file_to_load=trained_word2vec_model_file)
print('all_vectors length: %d' % len(all_vectors[0]))
x_test = np.array(all_vectors)
print("x_test.shape = {}".format(x_test.shape))

# Evaluation
# ==================================================
print("\nEvaluating...\n")
graph = tf.Graph()
with graph.as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Load the saved meta graph and restore variables
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)

        # Get the placeholders from the graph by name
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

        # Tensors we want to evaluate
        predictions = graph.get_operation_by_name("output/predictions").outputs[0]

        # Predict in consecutive, non-shuffled slices of x_test.
        chunks = []
        for start in range(0, len(x_test), FLAGS.batch_size):
            x_test_batch = x_test[start:start + FLAGS.batch_size]
            chunks.append(sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0}))
        all_predictions = np.concatenate(chunks) if chunks else np.array([], dtype=np.int64)

print('prediction data num: ', len(all_predictions))

# Map one-hot positions back to the original class ids.  Valid positions are
# 0 .. len(labelToVec) - 1; ids are collected in parallel so they stay aligned.
all_pre = []
pre_ids = []
for pre, cid in zip(all_predictions, new_content_id):
    idx = int(pre)
    if 0 <= idx < len(labelToVec):
        all_pre.append(labelToVec[idx])
        pre_ids.append(cid)

# Save the evaluation to database
# data.insertPredict(pre_ids, all_pre)
print("SUCCESS!")

# Save the evaluation to a csv (raw text, predicted position, row id)
predictions_human_readable = np.column_stack(
    (np.array(new_x_raw), all_predictions, np.array(new_content_id)))
out_path = os.path.join(FLAGS.checkpoint_dir, "..", "prediction.csv")
print("Saving evaluation to {0}".format(out_path))
with open(out_path, 'w', newline='', encoding='utf-8') as f:
    csv.writer(f).writerows(predictions_human_readable)
至此,整个程序就完整了,实验的结果还不错,如果有什么更好的方法,大家也可以交流一下~共同进步嘛