深度学习笔记_神经网络八股功能扩展及实现应用

神经网络八股功能扩展

  • 1、自制数据集,解决本领域问题
  • 2、数据增强,扩充数据集
  • 3、断点续训,存取模型
  • 4、参数提取,存入文本可视化
  • 5、可视化绘图,查看训练效果
  • 6、应用程序,给图识物

1、自制数据集,解决本领域问题

制作可以用来训练和测试的数据集。
MNIST数据集、Fashion数据集是已经制作好的数据集,包括特征和标签。
这里仍利用MNIST数据集中的图片,进行数据集制作。
图片:黑底白字灰度图,28*28,每个像素为0-255的整数
标签:即对应图片的数字

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

实现自制数据集,模仿上述代码,即实现自制的
(x_train, y_train),(x_test, y_test) = fashion.load_data()
写出函数构建 (x_train, y_train) , (x_test, y_test)
具体过程忽略

2、数据增强,扩充数据集

数据增强,可以扩展数据集。
对图像的增强就是对图像的简单形变。
Tensorflow2给出了数据增强函数:
深度学习笔记_神经网络八股功能扩展及实现应用_第1张图片

from tensorflow.keras.preprocessing.image import ImageDataGenerator
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 给数据增加一个维度,使数据和网络结构匹配
image_gen_train = ImageDataGenerator(
    rescale=1. / 1.,  # 如为图像,分母为255时,可归至0~1
    rotation_range=45,  # 随机45度旋转
    width_shift_range=.15,  # 宽度偏移
    height_shift_range=.15,  # 高度偏移
    horizontal_flip=True,  # 水平翻转
    zoom_range=0.5  # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)  #这里是4维数据

原始数据:深度学习笔记_神经网络八股功能扩展及实现应用_第2张图片
数据增强后:
深度学习笔记_神经网络八股功能扩展及实现应用_第3张图片

增强数据后,model.fit()中输入训练集应修改

model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5, validation_data=(x_test, y_test),
          validation_freq=1)

3、断点续训,存取模型

在之前的练习中,代码运行一次后模型中的参数就清空了。而断点续训可以存取模型。
深度学习笔记_神经网络八股功能扩展及实现应用_第4张图片
load_weights(文件路径名)可以读取模型

tf.keras.callbacks.ModelCheckpoint(filepath=路径文件名,save_weights_only=True/False,save_best_only=True/False)可以保存模型。
在执行过程中history=model.fit(callbacks=[cp_callback])

保存模型后生成.ckpt文件夹(同步生成索引表.index),包括内容:
深度学习笔记_神经网络八股功能扩展及实现应用_第5张图片

读取/存放代码如下:

#设置模型参数路径
check_point_path = './checkpoint/fashion.ckpt'
#如果文件存在,则模型读取参数
if os.path.exists(check_point_path+'.index'):   #索引表
    print('****************load the model***************')
    model.load_weights(check_point_path)
#保存模型
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_point_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
#执行模型时保存
history = model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1,
                    callbacks=[cp_callback])

4、参数提取,存入文本可视化

断点续训保存了训练过后的模型,本步骤将模型参数提取处理,存入文本。
提取可训练参数:model.trainable_variables返回模型中可训练的参数 np.set_printoptions(threshold = 超过多少省略提示)

np.set_printoptions(threshold=np.inf)  #设置打印间隔无限大#打印参数
print(model.trainable_variables)
#保存参数到txt\
file = open('./weights.txt','w')
for v in model.trainable_variables:
    file.write(str(v.shape),+'\n')
    file.write(str(v.numpy()),+'\n')
file.close()

5、可视化绘图,查看训练效果

将准确率上升,损失率下降可视化。acc曲线与loss曲线
在模型执行过程中history = history=model.fit(训练集数据,训练集标签,batch_size=,epochs=,validation_split=y用做测试集的比例,validation_freq=测试频率)
history:
训练集loss:loss
测试集loss:val_loss
训练集准确率:sparse_categorical_accuracy
测试集准确率:val_sparse_categorical_accuracy

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)#1行两列,第一个子图
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

深度学习笔记_神经网络八股功能扩展及实现应用_第6张图片

6、应用程序,给图识物

要想实现模型可用,还需编写一个应用程序。
三大步骤如下
1、复现模型(前向传播)model=()
2、加载参数model.load()
3、预测结果result=model.predice()
代码:

换

    img_arr=img_arr/255.0                    #归一化

    print("img_arr:",img_arr.shape)
    x_predict = img_arr[tf.newaxis,...]  #神经网络是一个batch输入的,因此要增加一个维度x_predict(1,28,28)
    #进行预测
    result = model.predict(x_predict)
    # 最大概率值输出,返回预测结果
    pred = tf.argmax(result,axis=1)   
    
    print('\n')
    tf.print(pred)

注意:喂入模型预测的数据格式要和训练模型的数据格式一致
学习到这里使用的网络均是全连接层Dense
下面将进行卷积网络的学习

你可能感兴趣的:(笔记,神经网络,深度学习,机器学习,python)