1.定义数据初始化
from torchvision import transforms

# Spatial size every image is resized to by the pipeline below.
image_size = (224, 224)

# Pre-processing steps, applied to each sample in the listed order.
# NOTE(review): the same pipeline, random flip included, is also handed to
# the validation dataset -- confirm that is intended.
_steps = [
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(image_size),
    # Tile the tensor three times along dim 0 (presumably turning a
    # single-channel image into a 3-channel one -- TODO confirm).
    transforms.Lambda(lambda img: img.repeat(3, 1, 1)),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform = transforms.Compose(_steps)
2.导入数据集
import torchvision.datasets

# Both splits share the storage root, the download flag and the transform;
# only the `train` flag differs.
_mnist_common = dict(root='~', download=True, transform=transform)
mnist_train = torchvision.datasets.MNIST(train=True, **_mnist_common)
mnist_val = torchvision.datasets.MNIST(train=False, **_mnist_common)

# Print how many class labels each split reports (train first, then val).
for _split in (mnist_train, mnist_val):
    print(len(_split.classes))
3.制作DataLoader
from torch.utils.data import DataLoader

# Identical loader settings for both splits: 64 samples per batch, shuffled
# order, two worker processes.
_loader_options = {"batch_size": 64, "shuffle": True, "num_workers": 2}
trainloader = DataLoader(mnist_train, **_loader_options)
valloader = DataLoader(mnist_val, **_loader_options)
4.调用ResNet18 model
import torchvision.models as models
import torch

# Build ResNet-18 with `pretrained=True`.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; switch once the minimum supported torchvision is confirmed.
model = models.resnet18(pretrained=True)

# Replace the final fully connected layer with a fresh 10-output layer.
# The input width is read from the layer being replaced instead of being
# hard-coded, so the head stays valid if the backbone is swapped.
model.fc = torch.nn.Linear(model.fc.in_features, 10)
print(model)
5.初始化全连接层(fc)的权重
import torch
import torch.nn.init as init

# Re-initialise the weight matrix of the `fc` layer with
# `kaiming_uniform_` (fan-in mode, a=0). The layer is reached through the
# public `model.fc` attribute rather than by scanning the private
# `_modules` dict. Only the weight is touched; the bias is left as-is.
init.kaiming_uniform_(model.fc.weight, a=0, mode='fan_in')
6.调用GPU
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
7.定义准确率函数
import torch
def accuracy(pred, target):
    """Count the correct top-1 predictions in a batch.

    Args:
        pred: Tensor of per-class scores; the predicted label of each row
            is the argmax along dimension 1.
        target: Tensor of labels, compared element-wise with the predicted
            labels.

    Returns:
        A tuple ``(correct, total)``. ``correct`` is a 0-dim float tensor
        holding the number of matching predictions; ``total`` is
        ``len(pred)``.
    """
    pred_label = torch.argmax(pred, 1)
    # Tensor.sum() reduces in one vectorised call. The builtin sum() walked
    # the tensor element by element and returned the plain int 0 for an
    # empty batch, which has no `.to()` method.
    correct = (pred_label == target).sum().to(torch.float)
    return correct, len(pred)
8.定义字典来存放数据
# Per-split history containers: one list for accuracy values and one for
# loss values, keyed by 'train' and 'val'.
acc = {split: [] for split in ('train', 'val')}
loss_all = {split: [] for split in ('train', 'val')}
9.开始训练和验证
"""设为训练模式"""
import torch.nn.functional as F

# The model must live on the same device as the batches fed to it.
model.to(device)

# `optimizer` was used below without ever being created. Momentum SGD with
# a small learning rate is used as a fine-tuning default; tune as needed.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# ---- Training pass over the whole training set ----
model.train()
train_correctnum, train_prednum, train_total_loss = 0., 0., 0.
# The loaders are defined earlier in the file as `trainloader`/`valloader`.
for images, labels in trainloader:
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    loss = F.cross_entropy(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_total_loss += loss.item()
    correctnum, prednum = accuracy(outputs, labels)
    # float() turns the 0-dim (possibly CUDA) tensor into a Python number so
    # the stored history can be plotted later.
    train_correctnum += float(correctnum)
    train_prednum += prednum

# ---- Validation pass ----
model.eval()
valid_correctnum, valid_prednum, valid_total_loss = 0., 0., 0.
# No gradients are needed for evaluation; skipping them saves memory/time.
with torch.no_grad():
    for images, labels in valloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        valid_total_loss += loss.item()
        correctnum, prednum = accuracy(outputs, labels)
        valid_correctnum += float(correctnum)
        valid_prednum += prednum

# Average loss per batch.
train_loss = train_total_loss / len(trainloader)
valid_loss = valid_total_loss / len(valloader)
# Fraction of correctly classified samples.
train_acc = train_correctnum / train_prednum
valid_acc = valid_correctnum / valid_prednum

# Record this pass in the history dictionaries.
loss_all['train'].append(train_loss)
loss_all['val'].append(valid_loss)
acc['train'].append(train_acc)
acc['val'].append(valid_acc)

print('train_loss:{:.6f} \t valid_loss:{:.6f}'.format(train_loss, valid_loss))
print('train_acc:{:.6f} \t valid_acc:{:.6f}'.format(train_acc, valid_acc))
10.训练结果
11.绘制loss 和acc曲线
import matplotlib.pyplot as plt

# ---- Loss curves ----
plt.figure()
plt.ylim((0, 0.6))
plt.xlim((0, 10))
# float() accepts both Python numbers and 0-dim tensors (including CUDA
# ones), which matplotlib cannot convert by itself.
plt.plot([float(v) for v in loss_all['train']], color='orange', label='train')
plt.plot([float(v) for v in loss_all['val']], color='blue', label='val')
plt.title('loss function')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

# ---- Accuracy curves ----
plt.figure()
plt.ylim((0.8, 1))
plt.xlim((0, 10))
plt.plot([float(v) for v in acc['train']], color='orange', label='train')
plt.plot([float(v) for v in acc['val']], color='blue', label='val')
plt.title('accuracy rate')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
# The second figure was never displayed when run as a script.
plt.show()
12.完整代码
"""""""""""""""""""数据初始化"""""""""""""""""""""""""""
from torchvision import transforms

# Spatial size every image is resized to by the pipeline below.
image_size = (224, 224)

# Pre-processing steps, applied to each sample in the listed order.
# NOTE(review): the same pipeline, random flip included, is also handed to
# the validation dataset -- confirm that is intended.
_steps = [
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(image_size),
    # Tile the tensor three times along dim 0 (presumably turning a
    # single-channel image into a 3-channel one -- TODO confirm).
    transforms.Lambda(lambda img: img.repeat(3, 1, 1)),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform = transforms.Compose(_steps)
"""""""""""""""""""导入数据集"""""""""""""""""""
import torchvision.datasets

# Both splits share the storage root, the download flag and the transform;
# only the `train` flag differs.
_mnist_common = dict(root='~', download=True, transform=transform)
mnist_train = torchvision.datasets.MNIST(train=True, **_mnist_common)
mnist_val = torchvision.datasets.MNIST(train=False, **_mnist_common)

# Print how many class labels each split reports (train first, then val).
for _split in (mnist_train, mnist_val):
    print(len(_split.classes))
"""""""""""""""制作DataLoader"""""""""""""""""""""
from torch.utils.data import DataLoader

# Identical loader settings for both splits: 64 samples per batch, shuffled
# order, two worker processes.
_loader_options = {"batch_size": 64, "shuffle": True, "num_workers": 2}
trainloader = DataLoader(mnist_train, **_loader_options)
valloader = DataLoader(mnist_val, **_loader_options)
"""""""""""""""""""调用model"""""""""""""""""""
import torchvision.models as models
import torch

# Build ResNet-18 with `pretrained=True`.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; switch once the minimum supported torchvision is confirmed.
model = models.resnet18(pretrained=True)

# Replace the final fully connected layer with a fresh 10-output layer.
# The input width is read from the layer being replaced instead of being
# hard-coded, so the head stays valid if the backbone is swapped.
model.fc = torch.nn.Linear(model.fc.in_features, 10)
print(model)
""""""""""""""""""""""调用GPU"""""""""""""""""""""""""""""
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
"""""""""""""""计算准确率"""""""""""""""
import torch
def accuracy(pred, target):
    """Count the correct top-1 predictions in a batch.

    Args:
        pred: Tensor of per-class scores; the predicted label of each row
            is the argmax along dimension 1.
        target: Tensor of labels, compared element-wise with the predicted
            labels.

    Returns:
        A tuple ``(correct, total)``. ``correct`` is a 0-dim float tensor
        holding the number of matching predictions; ``total`` is
        ``len(pred)``.
    """
    pred_label = torch.argmax(pred, 1)
    # Tensor.sum() reduces in one vectorised call. The builtin sum() walked
    # the tensor element by element and returned the plain int 0 for an
    # empty batch, which has no `.to()` method.
    correct = (pred_label == target).sum().to(torch.float)
    return correct, len(pred)
# Per-split history containers: one list for accuracy values and one for
# loss values, keyed by 'train' and 'val'.
acc = {split: [] for split in ('train', 'val')}
loss_all = {split: [] for split in ('train', 'val')}
"""""""""""""""""验证和训练"""""""""""""""
import torch.nn.functional as F

# The model must live on the same device as the batches fed to it.
model.to(device)

# `optimizer` was used below without ever being created. Momentum SGD with
# a small learning rate is used as a fine-tuning default; tune as needed.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    print("epoch", epoch + 1, ":***************************")

    # ---- Training pass over the whole training set ----
    model.train()
    train_correctnum, train_prednum, train_total_loss = 0., 0., 0.
    # The loaders are defined earlier in the file as
    # `trainloader`/`valloader`.
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_total_loss += loss.item()
        correctnum, prednum = accuracy(outputs, labels)
        # float() turns the 0-dim (possibly CUDA) tensor into a Python
        # number so the stored history can be plotted later.
        train_correctnum += float(correctnum)
        train_prednum += prednum

    # ---- Validation pass ----
    model.eval()
    valid_correctnum, valid_prednum, valid_total_loss = 0., 0., 0.
    # No gradients are needed for evaluation; skipping them saves
    # memory and time.
    with torch.no_grad():
        for images, labels in valloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)
            valid_total_loss += loss.item()
            correctnum, prednum = accuracy(outputs, labels)
            valid_correctnum += float(correctnum)
            valid_prednum += prednum

    # Average loss per batch.
    train_loss = train_total_loss / len(trainloader)
    valid_loss = valid_total_loss / len(valloader)
    # Fraction of correctly classified samples.
    train_acc = train_correctnum / train_prednum
    valid_acc = valid_correctnum / valid_prednum

    # Record this epoch in the history dictionaries.
    loss_all['train'].append(train_loss)
    loss_all['val'].append(valid_loss)
    acc['train'].append(train_acc)
    acc['val'].append(valid_acc)

    print('train_loss:{:.6f} \t valid_loss:{:.6f}'.format(train_loss, valid_loss))
    print('train_acc:{:.6f} \t valid_acc:{:.6f}'.format(train_acc, valid_acc))
"""""""""""""""""绘图"""""""""""""""""""""
import matplotlib.pyplot as plt

# ---- Loss curves ----
plt.figure()
plt.ylim((0, 0.6))
plt.xlim((0, 10))
# float() accepts both Python numbers and 0-dim tensors (including CUDA
# ones), which matplotlib cannot convert by itself.
plt.plot([float(v) for v in loss_all['train']], color='orange', label='train')
plt.plot([float(v) for v in loss_all['val']], color='blue', label='val')
plt.title('loss function')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

# ---- Accuracy curves ----
plt.figure()
plt.ylim((0.8, 1))
plt.xlim((0, 10))
plt.plot([float(v) for v in acc['train']], color='orange', label='train')
plt.plot([float(v) for v in acc['val']], color='blue', label='val')
plt.title('accuracy rate')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
# The second figure was never displayed when run as a script.
plt.show()