【神经网络八股扩展】:数据增强

课程来源:人工智能实践:Tensorflow笔记2

文章目录

  • 前言
  • TensorFlow2数据增强函数
  • 数据增强+网络八股代码:
  • 总结


前言

本讲目标:数据增强,增大数据量
关于我们为何要使用数据增强以及常用的几种数据增强的手法,可以看看下面的文章,虽说是翻译的,但仍有可鉴之处:
数据增强(Data Augmentation)


TensorFlow2数据增强函数

对图像的增强就是对图像的简单形变,用来应对因为拍照角度不同引起的图片形变。
TensorFlow2给出了数据增强函数

image_gen_train =tf.keras.preprocessing.image.ImageDataGenerator(
rescale = 所有数据将乘以该数值
rotation_range =随机旋转角度数范围
width_shift_range = 随机宽度偏移量
height_shift_range =随机高度偏移量
horizontal_flip =是否随机水平翻转
zoom_range =随机缩放的范围[1-n,1+n])
image_gen_train.fit(x_train)
fit需要输入4维数据,所以将x_train reshape为(60000,28,28,1)
这个1表示单通道灰度值
model.fit同步更新为.flow形式

划红线的部分为需要注意的地方:
【神经网络八股扩展】:数据增强_第1张图片

数据增强+网络八股代码:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 给数据增加一个维度,使数据和网络结构匹配

image_gen_train = ImageDataGenerator(
    rescale=1. / 1.,  # 如为图像,分母为255时,可归至0~1
    rotation_range=45,  # 随机45度旋转
    width_shift_range=.15,  # 宽度偏移
    height_shift_range=.15,  # 高度偏移
    horizontal_flip=True,  # 水平翻转
    zoom_range=0.5  # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5, validation_data=(x_test, y_test),
          validation_freq=1)
model.summary()

【神经网络八股扩展】:数据增强_第2张图片
随着迭代轮数增加,准确率不断提高。但从数据集上不能看出数据增强的效果,要在实际应用中去使用。

总结

课程链接:MOOC人工智能实践:TensorFlow笔记2

你可能感兴趣的:(#,机器学习实战,人工智能,深度学习,tensorflow,神经网络)