说明
- 数据集采用的是MNIST数据集(训练集60000个, 测试集10000个,单通道28*28的图片)
- 采用的网络模型结构
- 程序在GPU上运行(没有可用GPU时自动回退到CPU)。运行时可在终端执行
watch -n 1 nvidia-smi
来实时查看GPU的使用情况(每1秒刷新一次)。
- 目录结构
训练代码
import os

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.autograd import Variable
from torch.utils import data
from torchvision.datasets import mnist
# Preprocessing applied to every image: convert to a tensor, then normalize
# the single grayscale channel with mean 0.5 and std 0.5.
data_tf = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
])

# download=False: the MNIST files are expected to already be under data_path.
data_path = r'./mnist_data'
train_data = mnist.MNIST(
    root=data_path, train=True, transform=data_tf, download=False
)
test_data = mnist.MNIST(
    root=data_path, train=False, transform=data_tf, download=False
)

# Mini-batches of 128; only the training split is reshuffled each epoch.
train_loader = data.DataLoader(
    dataset=train_data, batch_size=128, shuffle=True, num_workers=2
)
test_loader = data.DataLoader(
    dataset=test_data, batch_size=128, shuffle=False, num_workers=2
)
class CNNnet(torch.nn.Module):
    """Small CNN classifier: four conv stages followed by two linear layers.

    Every conv stage is Conv2d -> BatchNorm2d -> ReLU. For a 1x28x28 input
    the spatial size shrinks 28 -> 14 -> 7 -> 4 -> 2, so the last stage
    yields a 64x2x2 feature map, which is what ``mlp1`` expects (2*2*64).
    The network returns raw 10-way logits (no softmax).

    Attribute names and creation order (conv1..conv4, mlp1, mlp2) define the
    ``state_dict`` keys, so they must stay stable for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._conv_block(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = self._conv_block(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = self._conv_block(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = self._conv_block(64, 64, kernel_size=2, stride=2, padding=0)
        self.mlp1 = torch.nn.Linear(2 * 2 * 64, 100)
        self.mlp2 = torch.nn.Linear(100, 10)

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, stride, padding):
        """Build one Conv2d -> BatchNorm2d -> ReLU stage."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input ``x``.

        ``x`` is a batch of single-channel images; 28x28 is required for the
        flattened feature size to match ``mlp1``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten everything except the batch dimension before the linear layers.
        x = self.mlp1(x.view(x.size(0), -1))
        x = self.mlp2(x)
        return x
# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CNNnet()
model = model.to(device)
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training procedure, per mini-batch:
# - compute the loss:                      loss = loss_func(outputs, labels)
# - clear gradients left from the last step: optimizer.zero_grad()
# - back-propagate the error:              loss.backward()
# - apply the update to the parameters:    optimizer.step()

# Explicitly select training mode so BatchNorm uses batch statistics.
model.train()

loss_count = []  # per-step loss values as plain Python floats, for plotting
for epoch in range(10):
    running_loss = 0.0
    for step, (x, y) in enumerate(train_loader, 0):
        inputs = x.to(device)
        labels = y.to(device)

        outputs = model(inputs)
        loss = loss_func(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a float, not the tensor: a list of (possibly CUDA) tensors
        # cannot be plotted by matplotlib and keeps device memory alive.
        loss_value = loss.item()
        loss_count.append(loss_value)
        running_loss += loss_value

        # Report the mean loss of every 100 consecutive steps.
        if step % 100 == 99:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, step + 1, running_loss / 100))
            running_loss = 0.0

# torch.save does not create missing parent directories.
os.makedirs('./model', exist_ok=True)
torch.save(model.state_dict(), './model/model_mnist.pth')

plt.figure('PyTorch_CNN_Loss')
plt.plot(loss_count, label='Loss')
plt.legend()
plt.show()
print('Finished Training')
测试代码
import torch
from torch.utils import data
from torch.autograd import Variable
import torchvision
from torchvision.datasets import mnist
import matplotlib.pyplot as plt
# Use the first GPU when available, otherwise fall back to the CPU.
is_support = torch.cuda.is_available()
device = torch.device('cuda:0' if is_support else 'cpu')

# Preprocessing must match what the network was trained with. MNIST images
# have a single channel, so Normalize takes one mean and one std; 3-channel
# statistics do not fit a 1x28x28 tensor.
data_tf = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),
])

# download=False: the MNIST files are expected to already be under 'mnist_data'.
test_data = mnist.MNIST('mnist_data', train=False, transform=data_tf, download=False)
test_loader = data.DataLoader(test_data, batch_size=128, shuffle=False, num_workers=2)
class CNNnet(torch.nn.Module):
    """Small CNN classifier: four conv stages followed by two linear layers.

    Every conv stage is Conv2d -> BatchNorm2d -> ReLU. For a 1x28x28 input
    the spatial size shrinks 28 -> 14 -> 7 -> 4 -> 2, so the last stage
    yields a 64x2x2 feature map, which is what ``mlp1`` expects (2*2*64).
    The network returns raw 10-way logits (no softmax).

    Attribute names and creation order (conv1..conv4, mlp1, mlp2) define the
    ``state_dict`` keys, so they must match the saved checkpoint being loaded.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._conv_block(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = self._conv_block(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = self._conv_block(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = self._conv_block(64, 64, kernel_size=2, stride=2, padding=0)
        self.mlp1 = torch.nn.Linear(2 * 2 * 64, 100)
        self.mlp2 = torch.nn.Linear(100, 10)

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, stride, padding):
        """Build one Conv2d -> BatchNorm2d -> ReLU stage."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input ``x``.

        ``x`` is a batch of single-channel images; 28x28 is required for the
        flattened feature size to match ``mlp1``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten everything except the batch dimension before the linear layers.
        x = self.mlp1(x.view(x.size(0), -1))
        x = self.mlp2(x)
        return x
# Rebuild the network and restore the trained weights. map_location remaps
# the stored tensors onto the selected device, so a checkpoint written on a
# GPU machine can still be loaded on a CPU-only one.
model = CNNnet()
model.load_state_dict(torch.load('./model/model_mnist.pth', map_location=device))
model.to(device)
# Inference mode: BatchNorm must use its stored running statistics instead of
# per-batch statistics, and must not keep updating them while testing.
model.eval()

correct = 0  # number of correctly classified samples so far
total = 0    # number of samples seen so far
count = 0    # number of batches processed so far
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        pre_labels = model(images)
        # The predicted class is the index of the largest logit.
        _, pred = torch.max(pre_labels, 1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        count += 1
        # NOTE: the value printed is the cumulative accuracy over all batches
        # processed so far, not the accuracy of this batch alone.
        print("在第{0}个batch中的Acc为:{1}" .format(count, correct/total))

accuracy = float(correct) / total
print("====================== Result =============================")
print('测试集上平均Acc = {:.5f}'.format(accuracy))
print("测试集共样本{0}个,分为{1}个batch,预测正确{2}个".format(total, count, correct))
结果
参考:
- 使用PyTorch实现CNN
- pytorch学习(十)—训练并测试CNN网络