tensorflow2.0学习笔记——模型改进扩展

  • 本文内容为北大tensorflow笔记的课堂笔记

1.图像数据增强

  • 数据增强就是扩展数据,对图像的数据增强就是对图片进行简单的变形。
image_ten_train = tf.keras.preprocessing.image.ImageDataGenerator(
	rescale=所有数据将乘以该值
	rotation_range=随机旋转角度数范围,用来对图像进行随机旋转
	width_shift_range=随机宽度偏移量
	height_shift_range=随机高度偏移量
	水平翻转:horizontal_flip=是否随机翻转
	随机缩放:zoom_range=随机缩放的范围[1-n, 1+n])
  • 应用举例:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 定义增强方法
image_gen_train = ImageDataGenerator(
    rescale=1. / 1.,  # 如为图像,分母为255时,可归至0~1,即rescale=1./255.
    rotation_range=45,  # 随机45度旋转
    width_shift_range=.15,  # 宽度偏移
    height_shift_range=.15,  # 高度偏移
    horizontal_flip=False,  # 水平翻转
    zoom_range=0.5  # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train) # 需要进行一次fit操作,来对数据进行增强

2.断点续训

  • 断电续训用于存取模型

2.1 读取模型

可以使用tensorflow中的load_weight(路径文件名)来读取模型。
例如:

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
# 因为生成ckpt文件时会同步生成索引表,所以可以通过判断是不是已经有了索引表,
# 来判断是否已经保存模型
	print('--------load the model--------')
	model.load_weights(checkpoint_save_path)

2.2 保存模型

可以使用tensorflow的回调函数来保存模型参数

tf.keras.callbacks.ModelCheckpoint(
	filepath=路径文件名, 
	save_weights_only=True/False,
	# 是否只保留模型参数
	save_best_only=True/False)
	# 是否只保留最优结果

history=model.fit( callbacks=[cp_callback])
# 需要在执行训练过程时加入callbacks选项

3.参数提取

  • 将保存的模型参数存入文本
  • 可用model.trainable_variable 返回模型中可训练的参数,可用print直接打印出这些参数,不过直接使用print中间会有很多数据被省略号替换,可以使用**np.set_printoptions(threshold=超过多少省略显示)**来设置打印效果。
  • 示例
np,set_printoptions(threshold=np.inf)
print(model.trainable_variables) # 打印所有可训练参数

file = open('.weight.txt', 'w')
# 将可训练参数存入文本中
for v in model.trainable_variables:
	file.write(str(v.name) + '\n')
	file.write(str(v.shape) + '\n')
	file.write(str(v.numpy()) + '\n')

你可能感兴趣的:(深度学习基础)