PyTorch入门(四):训练一个分类器

关于数据
通常,处理图像,文本,音频或视频数据时,使用标准python包将数据加载到numpy数组中,然后将数据转换成torch.*Tensor

  • 对于图像,使用Pillow,OpenCV
  • 对于音频,使用scipy和librosa
  • 对于文本,可以使用raw Python或者基于Cython的加载,或者NLTK、SpaCy
    针对视觉处理,有一个torchvision的软件包,包含常见数据集的数据加载器,比如ImageNet,CIFAR10,MNIST等,以及用于图片的数据转换器,即torchvision.datasetstorch.utils.data.DataLoader

训练一个图像分类器

  1. 通过torchvision加载和归一化CIFAR10的训练和测试数据
  2. 定义一个卷积神经网络
  3. 定义一个损失函数
  4. 使用训练数据训练网络
  5. 使用测试数据测试网络

1.加载数据
torchvision数据集的输出是范围在[0,1]的PILImage图像,需要将其归一化为范围在[-1,1]的张量。

import torch
import torchvision
import torchvision.transforms as trans

transform = trans.Compose([trans.ToTensor(),trans.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data',train=True,
                                        download=True,transform=transform)
trainloader = torch.utils.data.Dataloader(trainset,batchsize=10,
                                          shuffle=True,num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data',train=True,
                                        download=True,transform=transform)
testloader = torch.utils.data.Dataloader(testset,batchsize=10,
                                          shuffle=True,num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

可通过如下代码显示图片

def imshow(img):
    img = img/2+0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()

dataiter = iter(trainloader)
images,labels = dataiter.next()

imshow(torchvision.utils.make_grid(images))
for i in range(4):
    print('%5s'%classes[labels[i]])

PyTorch入门(四):训练一个分类器_第1张图片
在这里插入图片描述
2.定义神经网络

class Net():
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(in_features=16*5*5,out_features=120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self,x):
        x = self.pool(f.relu(self.conv1(x)))
        x = self.pool(f.relu(self.conv2(x)))
        x = x.view(-1,16*5*5)
        x = f.relu(self.fc1(x))
        x = f.relu(self.fc2(x))
        x = self.fc3(x)
        return x

3.定义损失函数和优化方式
使用交叉熵损失函数和随机梯度下降(SGD)

criteration = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

4.训练网络
我们只需要遍历我们的数据迭代器,并向网络提供输入并进行优化。

for epoch in range(1):
    running_loss = 0.0
    for i,data in enumerate(trainloader,0):
        inputs,labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criteration(outputs,labels)
        loss.backward()
        optimizer.step()

        running_loss +=loss.item()
        if i%2000 ==1999:
            print('[%d,%5d] loss:%.3f'%(epoch+1,i+1,running_loss/2000))
            running_loss = 0.0

print('Finished Training!')

PyTorch入门(四):训练一个分类器_第2张图片

5.使用测试数据测试网络
通过预测神经网络输出的类标签,与正确类做对比。如果预测正确,将样本添加到正确的预测数组中。
首先显示一些测试图片

testiter = iter(testloader)
images,labels = testiter.next()

imshow(torchvision.utils.make_grid(images))
for i in range(4):
    print('%5s'%classes[labels[i]])

PyTorch入门(四):训练一个分类器_第3张图片
在这里插入图片描述
输出为10个类的概率,一个类的概率越大,网络就越认为这个图像属于特定类,所以将概率最大的类作为预测类。

outputs = net(images)
predicted = torch.argmax(outputs,1)
for i in range(4):
    print('Predicted:%5s',classes[predicted[i]])

在这里插入图片描述
通过以下代码可以得到网络对整个数据集的性能

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images,labels = data
        outputs = net(images)
        predicted = torch.argmax(outputs,1)
        total += labels.size(0)
        correct += (predicted==labels).sum().item()

print('Accuracy of the network on the 1000 test images:%d%%'%(100*correct/total))

在这里插入图片描述
通过如下代码可以看到在每一类的性能

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        predicted = torch.argmax(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

PyTorch入门(四):训练一个分类器_第4张图片

你可能感兴趣的:(PyTorch学习)