# -*- coding: utf-8 -*-
# Reads CSV-format input data; outputs train_loss and test_acc curves.
import tensorflow as tf
import os
from matplotlib import pyplot as plt
from sklearn import preprocessing
# Paths to the data files and label files for the training and test sets
x_file_path_train = 'D:/桌面文件夹/重采样后/800维信号数据(csv)/train/train_data.csv'
y_file_path_train = 'D:/桌面文件夹/重采样后/800维信号数据(csv)/train/train_label.csv'
x_file_path_test = 'D:/桌面文件夹/重采样后/800维信号数据(csv)/test/test_data.csv'
y_file_path_test = 'D:/桌面文件夹/重采样后/800维信号数据(csv)/test/test_label.csv'
# Note: tf.data.experimental.CsvDataset is still experimental in TensorFlow; a corresponding stable API may be added later, so use it with caution.
def get_data(x_len, y_len, x_file_path, y_file_path, header_status_x=False, header_status_y=False):
    """Build the training set and return a dataset usable by TensorFlow.

    x_len: dimensionality of each data sample
    y_len: dimensionality of each label
    x_file_path: path to the data file
    y_file_path: path to the label file
    header_status_x: whether the data file has a header row (default False)
    header_status_y: whether the label file has a header row (default False)"""
record_default = [[0.0]] * x_len
ds = tf.data.experimental.CsvDataset(x_file_path, record_default, header=header_status_x)
# ds_1 = preprocessing.MinMaxScaler(feature_range=(0, 1), copy=1)
# ds = ds_1.fit_transform(ds)
ds = ds.map(lambda *items: tf.stack(items))
    label_record_default = [[0.0]] * y_len
ds_label = tf.data.experimental.CsvDataset(y_file_path, label_record_default, header=header_status_y)
ds_label = ds_label.map(lambda *items: tf.stack(items))
    # Difference from get_data_1: repeat() makes the training set iterate indefinitely
dataset = tf.data.Dataset.zip((ds, ds_label)).batch(10).repeat()
    # shuffle's argument is the buffer size: larger buffers shuffle more thoroughly (note that whole batches are shuffled here, since shuffle follows batch)
dataset = dataset.shuffle(1000)
return dataset
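# Note: the commented-out sklearn MinMaxScaler lines above cannot be applied to
# a tf.data.Dataset; sklearn scalers expect an in-memory array. A minimal
# in-pipeline alternative is sketched below (assumption: per-sample min-max
# scaling to [0, 1] is acceptable for this signal data; minmax_scale is a
# hypothetical helper, not part of the original pipeline):
def minmax_scale(x):
    """Scale one sample to [0, 1]; the small epsilon avoids division by zero."""
    x_min = tf.reduce_min(x)
    x_max = tf.reduce_max(x)
    return (x - x_min) / (x_max - x_min + 1e-8)
# Usage would be ds = ds.map(minmax_scale) right after the tf.stack map above.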
def get_data_1(x_len, y_len, x_file_path, y_file_path, header_status_x=False, header_status_y=False):
    """Same inputs and output format as get_data; the only difference is that repeat() is not applied."""
record_default = [[0.0]] * x_len
ds = tf.data.experimental.CsvDataset(x_file_path, record_default, header=header_status_x)
# ds_1 = preprocessing.MinMaxScaler(feature_range=(0,1),copy=1)
# ds = ds_1.fit_transform(ds)
ds = ds.map(lambda *items: tf.stack(items))
    label_record_default = [[0.0]] * y_len
ds_label = tf.data.experimental.CsvDataset(y_file_path, label_record_default, header=header_status_y)
ds_label = ds_label.map(lambda *items: tf.stack(items))
    # Difference from get_data: no repeat(), so the test set yields each sample once
dataset = tf.data.Dataset.zip((ds, ds_label)).batch(10)
dataset = dataset.shuffle(1000)
return dataset
# Build the training and test datasets
dataset_train = get_data(800, 1, x_file_path_train, y_file_path_train)
dataset_test = get_data_1(800, 1, x_file_path_test, y_file_path_test)
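# Quick sanity check (a sketch): with batch(10) and x_len=800, one batch should
# have shape (10, 800) for the data and (10, 1) for the labels.
for x_batch, y_batch in dataset_train.take(1):
    print('data batch:', x_batch.shape, 'label batch:', y_batch.shape)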
################ Training ################
# Model definition: a fully connected (BP) network with 800-dim input and a 10-class softmax output
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(64, input_dim=800, activation='relu'),
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dense(128,activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
### Choice of optimizer, loss function, and metric
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
# Path for saving model weights (ckpt format)
checkpoint_save_path = "./checkpoint/speed_up.ckpt"
# Resume training from a checkpoint: load weights if a previous run saved any.
# save_weights_only checkpoints are written in TF-checkpoint format, so the
# file that actually exists on disk is speed_up.ckpt.index; test for that.
if os.path.exists(checkpoint_save_path + '.index'):
print('-------------load the model-----------------')
model.load_weights(checkpoint_save_path)
# Checkpoint strategy: save the model's weights only, keeping the best-performing ones
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)
# Train: 5 epochs of 1000 steps each (5000 batches in total); validate once per epoch and save weights through the checkpoint callback
history = model.fit(dataset_train,epochs=5,steps_per_epoch=1000, validation_data=dataset_test, validation_freq=1,
callbacks=[cp_callback])
model.summary()
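# Optional explicit evaluation on the test set (a sketch; this works because
# get_data_1 omits repeat(), so dataset_test is finite and evaluate terminates):
test_loss, test_acc = model.evaluate(dataset_test)
print('test_loss:', test_loss, 'test_acc:', test_acc)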
# Dump the trainable weights to a text file
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
############################################### show ###############################################
# Extract the acc and loss curves for the training and validation sets
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# Plot the accuracy curves
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
# Plot the loss curves
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()