在训练比较大、耗时较久的网络时,如果突然停电、断网或者一些意外情况发生导致训练中断,那么已经训练好的内容可能全部丢失,这时我们就需要在训练过程中把一些时间点的checkpoint保存下来,及时训练意外中断,那么我们也可以在之后把这些checkpoint下载下来,重新开始训练。
(谁能想到我刚刚码好这段话就停电了呢????)
cifar-10+resnet.
一样,重点在load_state_dict的,可以直接跳转:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
trans = transforms.Compose((transforms.Resize(32),transforms.ToTensor()))
cifar_train = datasets.CIFAR10('cifar',train = True,transform=trans)
cifar_train_batch = DataLoader(cifar_train,batch_size=30,shuffle=True)
cifar_test = datasets.CIFAR10('cifar',train = False,transform=trans)
cifar_test_batch = DataLoader(cifar_test,batch_size=30,shuffle=True)
#搭建resnet
class resblock(nn.Module):
def __init__(self,ch_in,ch_out,stride):
super(resblock,self).__init__()
self.conv_1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
self.bn_1 = nn.BatchNorm2d(ch_out)
self.conv_2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
self.bn_2 = nn.BatchNorm2d(ch_out)
self.ch_in,self.ch_out,self.stride = ch_in,ch_out,stride
self.ch_trans = nn.Sequential()
if ch_in != ch_out:
self.ch_trans = nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(self.ch_out))
#ch_trans表示通道数转变。因为要做short_cut,所以x_pro和x_ch的size应该完全一致
def forward(self,x):
x_pro = F.relu(self.bn_1(self.conv_1(x)))
x_pro = self.bn_2(self.conv_2(x_pro))
#short_cut:
x_ch = self.ch_trans(x)
out = x_pro + x_ch
return out
class resnet(nn.Module):
def __init__(self):
super(resnet,self).__init__()
self.conv_1 = nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(64))
self.block1 = resblock(64,128,2) #长宽减半 32/2=16
self.block2 = resblock(128,256,2) #长宽再减半 16/2=8
self.block3 = resblock(256,512,1)
self.block4 = resblock(512,512,1)
self.outlayer = nn.Linear(512,10) #512*8*8=32768
def forward(self,x):
x = F.relu(self.conv_1(x))
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = F.adaptive_avg_pool2d(x,[1,1])
x = x.reshape(x.size(0),-1)
result = self.outlayer(x)
return result
device = torch.device('cuda')
net = resnet()
net = net.to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters(),lr=1e-3)
#开始训练
for epoch in range(5):
for batchidx,(x,label) in enumerate(cifar_train_batch):
x,label = x.to(device),label.to(device) #x.size (bcs,3,32,32) label.size (bcs)
logits = net.forward(x)
loss = loss_fn(logits,label) #logits.size:bcs*10,label.size:bcs
#开始反向传播:
optimizer.zero_grad()
loss.backward() #计算gradient
optimizer.step() #更新参数
if (batchidx+1)%400 == 0:
print('这是本次迭代的第{}个batch'.format(batchidx+1)) #本例中一共有50000张照片,每个batch有30张照片,所以一个epoch有1667个batch
'''
就是这里!!!每400个batch就存一次checkpoint,
存到指定的文件,这里我设的是一个TXT文件
'''
torch.save(net.state_dict(),"./resnet_ckp.txt")
print('这是第{}迭代,loss是{}'.format(epoch+1,loss.item()))
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第1迭代,loss是0.7839802503585815
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第2迭代,loss是1.0195786952972412
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第3迭代,loss是0.5244616866111755
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第4迭代,loss是0.6468905210494995
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第5迭代,loss是0.8967750668525696
net.eval()
with torch.no_grad():
correct_num = 0
total_num = 0
batch_num = 0
for x,label in cifar_test_batch: #x的size是30*3*32*32(30是batch_size,3是通道数),label的size是30.
#cifar_test中一共有10000张照片,所以一共有334个batch,因此要循环334次
x,label = x.to(device),label.to(device)
logits = net.forward(x)
pred = logits.argmax(dim=1)
correct_num += torch.eq(pred,label).float().sum().item()
total_num += x.size(0)
batch_num += 1
if batch_num%50 == 0:
print('这是测试集上的第{}个batch'.format(batch_num)) #一共有10000/30≈334个batch
acc = correct_num/total_num #最终的total_num是10000
print('测试集上的准确率为:',acc)
这是测试集上的第50个batch
这是测试集上的第100个batch
这是测试集上的第150个batch
这是测试集上的第200个batch
这是测试集上的第250个batch
这是测试集上的第300个batch
测试集上的准确率为: 0.7628
continue_net = resnet().to(device)
para_in_last_net = torch.load('./resnet_ckp.txt') #把之前网络的参数下载到para_in_last_net中
continue_net.load_state_dict(para_in_last_net) #把para_in_last_net加载到continue_net中
#其实这两步可以合到一起写 continue_net.load_state_dict(torch.load('./resnet_ckp.txt'))
#然后我们再训练2个epoch(这次我们就不保存checkpoint了)
for epoch in range(2):
for batchidx,(x,label) in enumerate(cifar_train_batch):
x,label = x.to(device),label.to(device) #x.size (bcs,3,32,32) label.size (bcs)
logits = continue_net.forward(x)
loss = loss_fn(logits,label) #logits.size:bcs*10,label.size:bcs
#开始反向传播:
optimizer.zero_grad()
loss.backward() #计算gradient
optimizer.step() #更新参数
if (batchidx+1)%400 == 0:
print('这是本次迭代的第{}个batch'.format(batchidx+1)) #本例中一共有50000张照片,每个batch有30张照片,所以一个epoch有1667个batch
# torch.save(net.state_dict(),"./resnet_ckp.txt")
print('这是第{}迭代,loss是{}'.format(epoch+1,loss.item()))
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第1迭代,loss是0.809099018573761
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第2迭代,loss是0.43857383728027344
continue_net.eval()
with torch.no_grad():
correct_num = 0
total_num = 0
batch_num = 0
for x,label in cifar_test_batch: #x的size是30*3*32*32(30是batch_size,3是通道数),label的size是30.
#cifar_test中一共有10000张照片,所以一共有334个batch,因此要循环334次
x,label = x.to(device),label.to(device)
logits = continue_net.forward(x)
pred = logits.argmax(dim=1)
correct_num += torch.eq(pred,label).float().sum().item()
total_num += x.size(0)
batch_num += 1
if batch_num%50 == 0:
print('这是测试集上的第{}个batch'.format(batch_num)) #一共有10000/30≈334个batch
acc = correct_num/total_num #最终的total_num是10000
print('测试集上的准确率为:',acc)
这是测试集上的第50个batch
这是测试集上的第100个batch
这是测试集上的第150个batch
这是测试集上的第200个batch
这是测试集上的第250个batch
这是测试集上的第300个batch
测试集上的准确率为: 0.7754