数据集是 IMDB Large Movie Review Dataset,包含用于训练的 25000 段带有明显情感倾向的电影评论,测试集也有 25000 段。我们将会用此数据集训练一个二分类模型,用于判断一篇评论是积极的还是消极的。
步骤:
"""
目的:对IMDB电影评论数据进行训练,预测分类
步骤:
1.读取电影评论数据
2.模型输入特征指定
3.模型训练与保存
"""
from tensorflow import keras
import tensorflow as tf
import functools
class Solution:
    """Train and evaluate a binary sentiment classifier on IMDB reviews.

    ``process`` runs the whole pipeline: load and pad the reviews, describe
    the padded word ids as an embedded categorical feature, train a
    ``tf.estimator.DNNClassifier`` and evaluate it on the test split.
    """

    def get_imdb_data(self, num_words=5000, maxlen=200):
        """Load the IMDB dataset and bring every review to a fixed length.

        Args:
            num_words: Forwarded to ``imdb.load_data(num_words=...)``.
            maxlen: Length of every returned review. Shorter reviews are
                padded at the end with 0 (``padding='post'``, ``value=0``).

        Returns:
            ``(x_train, y_train), (x_test, y_test)`` where the ``x_*`` values
            are the padded sequences and the ``y_*`` values are the labels
            exactly as returned by ``imdb.load_data``.
        """
        imdb = keras.datasets.imdb
        (x_train_data, y_train), (x_test_data, y_test) = imdb.load_data(num_words=num_words)

        # Both splits must be padded with identical settings, so bind them once.
        pad = functools.partial(
            keras.preprocessing.sequence.pad_sequences,
            maxlen=maxlen,
            padding='post',
            value=0,
        )
        x_train = pad(sequences=x_train_data)
        x_test = pad(sequences=x_test_data)
        return (x_train, y_train), (x_test, y_test)

    def parse(self, X, y):
        """Wrap one example as ``(feature_dict, label)``.

        Args:
            X: Feature value, stored under the ``"feature"`` key.
            y: Label, returned unchanged.

        Returns:
            A ``({"feature": X}, y)`` tuple. The key matches the one used for
            the categorical column built in ``process``.
        """
        feature_dict = {"feature": X}
        return feature_dict, y

    def input_fn(self, X, y, batch_size, nums_epochs):
        """Build the ``tf.data`` pipeline handed to the estimator.

        Args:
            X: Features, sliced along the first axis together with ``y``.
            y: Labels.
            batch_size: Number of examples per batch.
            nums_epochs: Number of times the batched dataset is repeated.

        Returns:
            A dataset of ``(feature_dict, label)`` batches.
        """
        examples = tf.data.Dataset.from_tensor_slices((X, y))
        # Order matters: map each example, then batch, then repeat.
        parsed = examples.map(self.parse)
        batched = parsed.batch(batch_size)
        return batched.repeat(nums_epochs)

    def train(self, hidden_units, feature_clos, model_dir, train_inputfn):
        """Create a ``DNNClassifier`` and train it.

        Args:
            hidden_units: A single value; it is wrapped in a one-element list
                before being passed as ``hidden_units``.
            feature_clos: A single feature column; it is wrapped in a
                one-element list before being passed as ``feature_columns``.
            model_dir: Passed to the estimator as ``model_dir``.
            train_inputfn: Callable passed to ``train`` as ``input_fn``.

        Returns:
            The result of ``DNNClassifier(...).train(...)``; ``process``
            passes it on to ``evaluate`` as the model.
        """
        classifier = tf.estimator.DNNClassifier(
            hidden_units=[hidden_units],
            feature_columns=[feature_clos],
            model_dir=model_dir,
        )
        return classifier.train(input_fn=train_inputfn)

    def evaluate(self, model, test_inputfn):
        """Evaluate ``model`` and return the result of ``model.evaluate``.

        Args:
            model: Object exposing ``evaluate(input_fn)``.
            test_inputfn: Callable passed to ``model.evaluate``.
        """
        return model.evaluate(test_inputfn)

    def process(self):
        """Run the full pipeline and print the evaluation results."""
        # The vocabulary cap used when loading the data and the bucket count
        # of the identity column must agree, so both come from one value.
        vocab_size = 5000

        # 1. Load the data and bind it into input functions. All four
        #    arguments are bound, so the partials take no further arguments.
        (x_train, y_train), (x_test, y_test) = self.get_imdb_data(num_words=vocab_size)
        train_inputfn = functools.partial(self.input_fn, x_train, y_train, 64, 5)
        test_inputfn = functools.partial(self.input_fn, x_test, y_test, 64, 1)

        # 2. Describe the model input: word ids embedded into 50 dimensions.
        column = tf.feature_column.categorical_column_with_identity(key='feature', num_buckets=vocab_size)
        t_embedding_column = tf.feature_column.embedding_column(column, dimension=50)

        # 3. Train the model.
        classifier = self.train(100, t_embedding_column, './tmp/model/dnn_txt1/', train_inputfn)

        # 4. Evaluate the model.
        results = self.evaluate(classifier, test_inputfn)
        print(results)
# Script entry point: run the whole pipeline when executed directly.
if __name__ == '__main__':
    solution = Solution()
    solution.process()
注:keras.preprocessing.sequence.pad_sequences(sequences, maxlen=None, dtype='int32',
padding='pre', truncating='pre', value=0.)
- sequences:浮点数或整数构成的两层嵌套列表
- maxlen:None或整数,为序列的最大长度。大于此长度的序列将被截短,小于此长度的序列将被填0补齐(在起始还是结尾填充由 padding 参数决定,默认为 'pre',即在起始处填充)。
- dtype:返回的numpy array的数据类型
- padding:‘pre’或‘post’,确定当需要补0时,在序列的起始还是结尾补
- truncating:‘pre’或‘post’,确定当需要截断序列时,从起始还是结尾截断
- value:浮点数,此值将在填充时代替默认的填充值0