用PyTorch实现一个卷积神经网络进行图像分类

1. 回顾

在进入这一篇博客的内容之前,我们先确保已经成功安装好PyTorch,可以参考我之前的一篇博客“Ubuntu12.04下PyTorch详细安装记录”:

http://blog.csdn.net/wblgers1234/article/details/72902016

接下来,我们用设计一个简单的卷积神经网络的方式来熟悉PyTorch的用法。

2. 设计卷积神经网络

在设计复杂的神经网络之前,我们依然考虑按照斯坦福大学的“UFLDL Tutorial”的CNN部分来构建一个简单的卷积神经网络,即按照以下的设计:

输入层->二维特征卷积->sigmoid激励->均值池化->全连接网络->softmax输出

按照下面的代码对应来看神经网络的结构。注释得很清晰,有不清楚的可以留言,这里就不再赘述。

class CNN_net(nn.Module):
    def __init__(self):
        # 先运行nn.Module的初始化函数
        super(CNN_net, self).__init__()
        # 卷积层的定义,输入为1channel的灰度图,输出为4特征,每个卷积kernal为9*9
        self.conv = nn.Conv2d(1, 4, 9)
        # 均值池化
        self.pool = nn.AvgPool2d(2, 2)
        # 全连接后接softmax
        self.fc = nn.Linear(10*10*4, 10)
        self.softmax = nn.Softmax()

    def forward(self, x):
        # 卷积层,分别是二维卷积->sigmoid激励->池化
        out = self.conv(x)
        out = F.sigmoid(out)
        out = self.pool(out)
        print(out.size())
        # 将特征的维度进行变化(batchSize*filterDim*featureDim*featureDim->batchSize*flat_features)
        out = out.view(-1, self.num_flat_features(out))
        # 全连接层和softmax处理
        out = self.fc(out)
        out = self.softmax(out)
        return out
    def num_flat_features(self, x):
        # 四维特征,第一维是batchSize
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

3. 数据准备

还记得torchvision吗?我们在做和图像有关的实验时会更多地与它打交道。这次我们选择最简单也是最广为人知的MNIST数据库来训练和测试CNN。同时在torchvision中有一个torchvision.datasets,它为很多常用的图像数据库提供接口,其中就包括MNIST。

from torchvision.datasets import MNIST

需要先下载MNIST,并且转换为PyTorch可以识别的数据格式:

# MNIST图像数据的转换函数
trans_img = transforms.Compose([
        transforms.ToTensor()
    ])

# 下载MNIST的训练集和测试集
trainset = MNIST('./MNIST', train=True, transform=trans_img, download=True)
testset = MNIST('./MNIST', train=False, transform=trans_img, download=True)

我们查看transforms.ToTensor()的解释,将原本的二维图像格式转换为PyTorch的基本单位torch.FloatTensor。

Converts a PIL.Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].

4. 训练和测试

4.1 训练数据集

从代码中可以清晰的看见“前向传播”,“反向传播”,optimizer的求解。

# 训练过程
for i in range(epoches):
    running_loss = 0.
    running_acc = 0.
    for (img, label) in trainloader:
        # 转换为Variable类型
        img = Variable(img)
        label = Variable(label)

        optimizer.zero_grad()

        # feedforward
        output = net(img)
        loss = criterian(output, label)
        # backward
        loss.backward()
        optimizer.step()

        # 记录当前的lost以及batchSize数据对应的分类准确数量
        running_loss += loss.data[0]
        _, predict = torch.max(output, 1)
        correct_num = (predict == label).sum()
        running_acc += correct_num.data[0]

    # 计算并打印训练的分类准确率
    running_loss /= len(trainset)
    running_acc /= len(trainset)

    print("[%d/%d] Loss: %.5f, Acc: %.2f" %(i+1, epoches, running_loss, 100*running_acc))

在训练完成之后,有一个处理很重要,需要将当前的网络设置为“测试模式”,然后才可以进行测试集的验证。

# 将当前模型设置到测试模式
net.eval()

4.2 测试数据集

在测试过程中,只有“前向传播”过程对输入的图像进行分类预测。

# 测试过程
testloss = 0.
testacc = 0.
for (img, label) in testloader:
    # 转换为Variable类型
    img = Variable(img)
    label = Variable(label)

    # feedforward
    output = net(img)
    loss = criterian(output, label)

    # 记录当前的lost以及累加分类正确的样本数
    testloss += loss.data[0]
    _, predict = torch.max(output, 1)
    num_correct = (predict == label).sum()
    testacc += num_correct.data[0]

# 计算并打印测试集的分类准确率
testloss /= len(testset)
testacc /= len(testset)
print("Test: Loss: %.5f, Acc: %.2f %%" %(testloss, 100*testacc))

4.3 代码运行结果

从下面的结果,可以看到迭代10次的训练分类准确率和测试分类准确率:

CNN_net (
  (conv): Conv2d(1, 4, kernel_size=(9, 9), stride=(1, 1))
  (pool): AvgPool2d (
  )
  (fc): Linear (400 -> 10)
  (softmax): Softmax ()
)
[1/10] Loss: 1.78497, Acc: 68.79
[2/10] Loss: 1.54269, Acc: 93.10
[3/10] Loss: 1.52096, Acc: 94.93
[4/10] Loss: 1.51040, Acc: 95.82
[5/10] Loss: 1.50393, Acc: 96.45
[6/10] Loss: 1.49967, Acc: 96.77
[7/10] Loss: 1.49655, Acc: 97.02
[8/10] Loss: 1.49401, Acc: 97.24
[9/10] Loss: 1.49192, Acc: 97.45
[10/10] Loss: 1.49050, Acc: 97.56
Test: Loss: 1.48912, Acc: 97.62 %

该工程完整的代码我已经放到github上,有兴趣的可以去下载试试:

https://github.com/wblgers/stanford_dl_cnn/tree/master/PyTorch

你可能感兴趣的:(PyTorch)