When training a neural network, if the program is suddenly interrupted, is all that work wasted?
>>>Resuming from checkpoints (断点续训) solves it!
Contents
(1) Checkpoint resumption with the PyTorch framework (Cats vs. Dogs)
(2) Checkpoint resumption with the TensorFlow framework (CIFAR-10 classification)
Since PyTorch and TensorFlow are currently the most popular deep-learning frameworks, this article implements checkpoint resumption for both. To run the code in one go, first download the 'Cats vs. Dogs' dataset from the Kaggle dataset site (Cat and Dog | Kaggle). You can also use your own dataset, or call the framework's built-in helpers to download one automatically, though the automatic route is not recommended; your own dataset works best.
The code in this article is pure hands-on code, with the explanations in the comments. I hope it helps.
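Before the full examples, here is the core pattern in isolation: periodically save everything needed to restart (model weights, optimizer state, current epoch) into a single file, and reload it at startup if that file exists. A minimal PyTorch sketch, assuming a model and optimizer of your own and a hypothetical checkpoint path:

import os
import torch

CKPT = 'last.pth'  # hypothetical checkpoint path

def save_ckpt(model, optimizer, epoch, path=CKPT):
    # Bundle everything needed to resume into one dict
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch}, path)

def load_ckpt(model, optimizer, path=CKPT):
    # Return the epoch to resume from (0 when no checkpoint exists)
    if not os.path.isfile(path):
        return 0
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['epoch'] + 1  # the epoch after the last completed one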
(1) Checkpoint resumption with the PyTorch framework (Cats vs. Dogs)
import os
import time
import PIL
import torch
import numpy as np
import cv2
import torch.nn as nn
# np.set_printoptions(threshold=np.inf)  # print full arrays (no ellipsis)
from matplotlib import pyplot as plt
from tqdm import tqdm
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
# ========================= Load the dataset ========================== #
# Every image is resized to 224x224; after the HWC->CHW transpose the stacked tensors are (N, 3, 224, 224)
train_cat_path = '../data/cat_dog/training_set/cats'
val_cat_path = '../data/cat_dog/val_set/cats'
train_dog_path = '../data/cat_dog/training_set/dogs'
val_dog_path = '../data/cat_dog/val_set/dogs'
train_cat_images = []
val_cat_images = []
train_dog_images = []
val_dog_images = []
temp_flag = 0  # debug counters for the commented-out early break in the loops below
stop_flag = 4
train_cat_files = os.listdir(train_cat_path)
val_cat_files = os.listdir(val_cat_path)
train_dog_files = os.listdir(train_dog_path)
val_dog_files = os.listdir(val_dog_path)
print('=' * 80)
print('Loading the dataset')
for train_cat_file in tqdm(train_cat_files):
    # Process the cat training images
    train_cat_image = cv2.imread(train_cat_path + '/' + train_cat_file)
    if train_cat_image is None:
        continue  # skip unreadable or corrupt files
    train_cat_image = cv2.resize(train_cat_image, (224, 224), interpolation=cv2.INTER_LINEAR)
    train_cat_image = np.transpose(train_cat_image, (2, 0, 1))  # HWC -> CHW
    # print(train_cat_image.shape)
    train_cat_images.append(torch.from_numpy(train_cat_image / 255))  # scale to [0, 1]
    # temp_flag += 1
    # if temp_flag == stop_flag:
    #     temp_flag = 0
    #     break
for train_dog_file in tqdm(train_dog_files):
    # Process the dog training images
    train_dog_image = cv2.imread(train_dog_path + '/' + train_dog_file)
    if train_dog_image is None:
        continue  # skip unreadable or corrupt files
    train_dog_image = cv2.resize(train_dog_image, (224, 224), interpolation=cv2.INTER_LINEAR)
    train_dog_image = np.transpose(train_dog_image, (2, 0, 1))  # HWC -> CHW
    # print(train_dog_image.shape)
    train_dog_images.append(torch.from_numpy(train_dog_image / 255))  # scale to [0, 1]
    # temp_flag += 1
    # if temp_flag == stop_flag:
    #     temp_flag = 0
    #     break
for val_cat_file in tqdm(val_cat_files):
    # Process the cat validation images
    val_cat_image = cv2.imread(val_cat_path + '/' + val_cat_file)
    if val_cat_image is None:
        continue  # skip unreadable or corrupt files
    val_cat_image = cv2.resize(val_cat_image, (224, 224), interpolation=cv2.INTER_LINEAR)
    val_cat_image = np.transpose(val_cat_image, (2, 0, 1))  # HWC -> CHW
    # print(val_cat_image.shape)
    val_cat_images.append(torch.from_numpy(val_cat_image / 255))  # scale to [0, 1]
    # temp_flag += 1
    # if temp_flag == stop_flag:
    #     temp_flag = 0
    #     break
for val_dog_file in tqdm(val_dog_files):
    # Process the dog validation images
    val_dog_image = cv2.imread(val_dog_path + '/' + val_dog_file)
    if val_dog_image is None:
        continue  # skip unreadable or corrupt files
    val_dog_image = cv2.resize(val_dog_image, (224, 224), interpolation=cv2.INTER_LINEAR)
    val_dog_image = np.transpose(val_dog_image, (2, 0, 1))  # HWC -> CHW
    # print(val_dog_image.shape)
    val_dog_images.append(torch.from_numpy(val_dog_image / 255))  # scale to [0, 1]
    # temp_flag += 1
    # if temp_flag == stop_flag:
    #     temp_flag = 0
    #     break
# torch.stack joins the tensors along a new (extra) first dimension
train_cat_data = torch.stack(train_cat_images, dim=0)
train_dog_data = torch.stack(train_dog_images, dim=0)
val_cat_data = torch.stack(val_cat_images, dim=0)
val_dog_data = torch.stack(val_dog_images, dim=0)
# print(train_cat_data.size())
# # Assemble the final datasets
# torch.cat concatenates tensors along an existing dim without adding a new one
train_data = torch.cat((train_cat_data, train_dog_data), dim=0)
# print(train_data.size())
val_data = torch.cat((val_cat_data, val_dog_data), dim=0)
# print(val_data.size())
# Define the label tensors: cats = 0, dogs = 1
train_labels_data = torch.cat((torch.zeros(len(train_cat_data)), torch.ones(len(train_dog_data))), dim=-1)
val_labels_data = torch.cat((torch.zeros(len(val_cat_data)), torch.ones(len(val_dog_data))), dim=-1)
print(train_labels_data.size())
print(val_labels_data.size())
# Shuffle the data
# Fix the seed so runs are reproducible, and apply the SAME permutation to the
# images and the labels so they stay aligned.
# (np.random.shuffle can silently corrupt torch tensors because rows are views,
#  so build an index permutation and index with it instead.)
torch.manual_seed(66)
train_perm = torch.randperm(len(train_data))
train_data = train_data[train_perm]
train_labels_data = train_labels_data[train_perm]
val_perm = torch.randperm(len(val_data))
val_data = val_data[val_perm]
val_labels_data = val_labels_data[val_perm]
print(val_labels_data)
# # Split the data into batches
# Define the batch size
batch_size = 32
# https://zhuanlan.zhihu.com/p/184911006
# shuffle=False matters here: labels are looked up by batch index during training,
# so the loader must keep the (already shuffled) order
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False, drop_last=False)
# print(train_loader)
train_steps = len(train_loader)  # number of batches per epoch (used to average the epoch loss)
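# Note (a sketch, not used below): wrapping images and labels in a TensorDataset
# makes the loader yield aligned (image, label) pairs, so shuffle=True becomes safe:
#     from torch.utils.data import TensorDataset
#     train_ds = TensorDataset(train_data.float(), train_labels_data.long())
#     train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)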
print('=' * 80)
# ========================= Build the VGG16 network ========================== #
num_classes = 2
init_weights = True
class VGG16_model(nn.Module):
def __init__(self):
super(VGG16_model, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # first fully connected layer (7 = 224 / 2**5 after five max-pools)
            nn.ReLU(True),
            nn.Dropout(p=0.5),  # randomly drop 50% of activations
            nn.Linear(4096, 4096),  # second fully connected layer
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)  # third fully connected layer
        )
        # Optionally initialize the weights
        if init_weights:
            self._initialize_weights()
        # self.add_module('Linear1', nn.Linear(512, 512))
        # self.add_module('Dropout1', nn.Dropout(0.2))
        # self.add_module('Linear2', nn.Linear(512, 512))
        # self.add_module('Dropout2', nn.Dropout(0.2))
        # self.add_module('Linear3', nn.Linear(512, 2))  # Cats vs. Dogs (two classes)
        # self.add_module('Dropout3', nn.Dropout(0.2))
    def forward(self, x):
        # print(x.size(0))  # batch size
        # Input shape [N, C, H, W]: batch size, channels, height, width
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x
def _initialize_weights(self):
# Returns an iterator over all modules in the network
for m in self.modules():
if isinstance(m, nn.Conv2d):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
# nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
# from torchsummary import summary
model = VGG16_model()
# summary(model, (3, 224, 224))
# print(model)
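# Optional sanity check (a sketch): the classifier's 512 * 7 * 7 input size comes
# from five 2x max-pools on a 224x224 image (224 / 2**5 = 7). Verify with a dummy batch:
#     with torch.no_grad():
#         print(model.features(torch.zeros(1, 3, 224, 224)).shape)  # torch.Size([1, 512, 7, 7])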
# ======================== Load pretrained weights ========================= #
# model_weight_path = '../pre_weights/vgg16-397923af.pth'
# assert os.path.exists(model_weight_path), 'model_weight_path {} does not exist'.format(model_weight_path)
# pre_weights = torch.load(model_weight_path, map_location=device)
# model.load_state_dict(pre_weights, strict=False)
# ========================= Loss and optimizer ========================== #
model.to(device)  # move the model onto the device
# # Loss function
# Computes the cross-entropy loss between the logits and the targets
loss_function = nn.CrossEntropyLoss()  # remember the parentheses: instantiate the class
# # Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# ========================= Save/restore the latest checkpoint ========================== #
temp_path = '../save_weights/temp_pth/VGG16_last.pth'
start_epoch = 0
RESUME = os.path.isfile(temp_path)  # resume only if a checkpoint file already exists
print(RESUME)
if RESUME:
    path_pth = temp_path  # checkpoint path
    pth = torch.load(path_pth)  # load the checkpoint
    model.load_state_dict(pth['model'])  # restore the model's learnable parameters
    optimizer.load_state_dict(pth['optimizer'])  # restore the optimizer state
    start_epoch = pth['epoch'] + 1  # resume from the epoch after the last completed one
    # scheduler.load_state_dict(pth['lr_scheduler'])  # restore the LR scheduler, if one is used
    print('Loaded the checkpoint saved before the interruption')
# ========================= Train the network ========================== #
# Containers for the accuracy/loss curves
acc_plt = []
loss_plt = []
loss = 0
epochs = 80
start_time = time.time()
for epoch in range(start_epoch, epochs):
start_time1 = time.time()
model.train()
running_loss = 0.0
train_bar = tqdm(train_loader)
for step, data in enumerate(train_bar):
        images = data.to(torch.float32)
        # print(images.size())
        # Labels for this batch, looked up by batch index (valid because shuffle=False)
        train_labels = train_labels_data[step * batch_size: (step + 1) * batch_size].long()
        # print(train_labels)
        optimizer.zero_grad()  # zero the gradients of all optimized tensors
        logits = model(images.to(device))  # forward the batch through the model
        # print(logits.size())
        loss = loss_function(logits, train_labels.to(device))
        loss.backward()
        optimizer.step()
        # Accumulate the loss
        running_loss += loss.item()
        # Update the progress-bar description
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)
    # Validate once per epoch
    model.eval()
    acc = 0.0
    with torch.no_grad():
        val_bar = tqdm(val_loader)
        for step, data in enumerate(val_bar):
            val_images = data.to(torch.float32)
            val_labels = val_labels_data[step * batch_size: (step + 1) * batch_size].long()
            # Forward the validation batch and take the class with the highest score
            outputs = model(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            # Count correct predictions in this batch:
            # torch.eq compares element-wise (True where equal),
            # .sum() counts the matches, .item() converts the 0-dim tensor to a Python number
            acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
    # Accuracy over the whole validation set
    val_accurate = acc / len(val_data)
    print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, running_loss / train_steps, val_accurate))
# los = running_loss / train_steps
# scheduler.step(los)
end_time1 = time.time()
cost_time1 = end_time1 - start_time1
print("训练一次的时间:", cost_time1)
    # ======================= Save the checkpoint ========================= #
    pth = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(pth, temp_path)
    print('Checkpoint for this epoch saved')
    # Saving every epoch (or every N epochs) keeps the run resumable and robust
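    # Note (a sketch, not in the original): if the process dies in the middle of
    # torch.save, the checkpoint file itself can be left half-written. Saving to a
    # temporary file and atomically renaming it avoids ever corrupting the checkpoint:
    #     tmp = temp_path + '.tmp'
    #     torch.save(pth, tmp)
    #     os.replace(tmp, temp_path)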
    # ======================= Log the epoch ========================= #
    # 'a': append (don't overwrite)
    # 'w': overwrite
    with open('VGG16.txt', mode='a', encoding='utf-8') as f:
        f.write('epoch:{} acc:{} loss:{} cost_time:{}'.format(epoch, val_accurate, loss.item(), cost_time1))
        f.write('\n')
    # Append to the curves for plotting
    acc_plt.append(val_accurate)
    loss_plt.append(running_loss)
end_time = time.time()
cost_time = end_time - start_time
print("总训练时间", cost_time)
(2) Checkpoint resumption with the TensorFlow framework (CIFAR-10 classification)
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model
np.set_printoptions(threshold=np.inf)  # print full matrices (no ellipsis)
# Load the CIFAR-10 dataset (10 classes of 32x32 color images)
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
class VGG16(Model):
    def __init__(self):
        super(VGG16, self).__init__()
        self.c1 = Conv2D(filters=64, kernel_size=(3, 3), padding='same')  # conv layer 1
        self.b1 = BatchNormalization()  # BN layer 1
        self.a1 = Activation('relu')  # activation 1
        self.c2 = Conv2D(filters=64, kernel_size=(3, 3), padding='same')
        self.b2 = BatchNormalization()
        self.a2 = Activation('relu')
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
        self.d1 = Dropout(0.2)  # dropout layer
        self.c3 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
        self.b3 = BatchNormalization()
        self.a3 = Activation('relu')
        self.c4 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
        self.b4 = BatchNormalization()
        self.a4 = Activation('relu')
        self.p2 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
        self.d2 = Dropout(0.2)
        self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
        self.b5 = BatchNormalization()
        self.a5 = Activation('relu')
        self.c6 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
        self.b6 = BatchNormalization()
        self.a6 = Activation('relu')
        self.c7 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
        self.b7 = BatchNormalization()
        self.a7 = Activation('relu')
        self.p3 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
        self.d3 = Dropout(0.2)
        self.c8 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
        self.b8 = BatchNormalization()
        self.a8 = Activation('relu')
        self.c9 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
        self.b9 = BatchNormalization()
        self.a9 = Activation('relu')
        self.c10 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
        self.b10 = BatchNormalization()
        self.a10 = Activation('relu')
        self.p4 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
        self.d4 = Dropout(0.2)
        self.c11 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
        self.b11 = BatchNormalization()
        self.a11 = Activation('relu')
        self.c12 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
        self.b12 = BatchNormalization()
        self.a12 = Activation('relu')
        self.c13 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
        self.b13 = BatchNormalization()
        self.a13 = Activation('relu')
        self.p5 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
        self.d5 = Dropout(0.2)
        self.flatten = Flatten()
        self.f1 = Dense(512, activation='relu')
        self.d6 = Dropout(0.2)
        self.f2 = Dense(512, activation='relu')
        self.d7 = Dropout(0.2)
        self.f3 = Dense(10, activation='softmax')  # 10 output classes for CIFAR-10
def call(self, x):
x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.c2(x)
x = self.b2(x)
x = self.a2(x)
x = self.p1(x)
x = self.d1(x)
x = self.c3(x)
x = self.b3(x)
x = self.a3(x)
x = self.c4(x)
x = self.b4(x)
x = self.a4(x)
x = self.p2(x)
x = self.d2(x)
x = self.c5(x)
x = self.b5(x)
x = self.a5(x)
x = self.c6(x)
x = self.b6(x)
x = self.a6(x)
x = self.c7(x)
x = self.b7(x)
x = self.a7(x)
x = self.p3(x)
x = self.d3(x)
x = self.c8(x)
x = self.b8(x)
x = self.a8(x)
x = self.c9(x)
x = self.b9(x)
x = self.a9(x)
x = self.c10(x)
x = self.b10(x)
x = self.a10(x)
x = self.p4(x)
x = self.d4(x)
x = self.c11(x)
x = self.b11(x)
x = self.a11(x)
x = self.c12(x)
x = self.b12(x)
x = self.a12(x)
x = self.c13(x)
x = self.b13(x)
x = self.a13(x)
x = self.p5(x)
x = self.d5(x)
x = self.flatten(x)
x = self.f1(x)
x = self.d6(x)
x = self.f2(x)
x = self.d7(x)
y = self.f3(x)
return y
model = VGG16()  # instantiate the VGG16 network
# Define the optimizer and loss function
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),  # from_logits=False: the model already applies softmax
              metrics=['sparse_categorical_accuracy'])
# Checkpoint save path; if a checkpoint already exists, load it before training
checkpoint_save_path = "./checkpoint/VGG16.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
# Callback that keeps the checkpoint file up to date during training
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
# Training/validation data and hyperparameters
# callbacks=[cp_callback] is the heart of checkpoint resumption: it saves the checkpoint that a restarted run reloads
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
callbacks=[cp_callback])
model.summary()  # print a summary of the model
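# Note (a sketch): ModelCheckpoint restores weights but not the epoch counter, so a
# restarted run begins counting from epoch 0 again. Newer TensorFlow versions also
# provide tf.keras.callbacks.BackupAndRestore, which checkpoints the training state
# itself and resumes at the interrupted epoch:
#     backup_cb = tf.keras.callbacks.BackupAndRestore(backup_dir='./backup')
#     history = model.fit(..., callbacks=[cp_callback, backup_cb])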
# print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.numpy()) + '\n')
file.close()
############################################### show ###############################################
# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
>>>Questions? Feel free to discuss in the comments.