pytorch 搭建AlexNet 对花进行分类

目录

1. 介绍

2. 搭建AlexNet网络

3. 准备数据集

4. 训练网络

5. 预测图片

6. code


文章内容参考:霹雳吧啦Wz 的视频教程

代码的讲解可以参考之前的文章:pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类

1. 介绍

AlexNet 网络的结构为:

pytorch 搭建AlexNet 对花进行分类_第1张图片

卷积层的计算公式为:

pytorch 搭建AlexNet 对花进行分类_第2张图片

 通过计算,可以得到网络之间的参数为:

pytorch 搭建AlexNet 对花进行分类_第3张图片

 

2. 搭建AlexNet网络

之前介绍过,卷积层相当于特征提取、全连接层相当于分类器

所以这里分开搭建不同的模块

特征提取的部分:

pytorch 搭建AlexNet 对花进行分类_第4张图片

  •  这里因为训练的数据较少,所以卷积核的数目都降为了一半
  • Sequential 是一个特殊的Module ,它包含了几个子module,前向传播时会将输入一层接着一层的传递下去
  • nn.ReLU(inplace = True)  , ReLU有个inplace参数,设置为Ture 的时候,它会把输出之间覆盖到输入中,这样可以节省资源。因为ReLU计算反向传播的时候,只需要根据输出就能反推出反向传播的梯度。(ReLU 反向传来的梯度会传给输入为正的部分

分类的部分:

pytorch 搭建AlexNet 对花进行分类_第5张图片

  •  AlexNet 使用了dropout 随机失活来防止过拟合,这是针对于全连接层而言的,所以要在每个全连接层的前面加上Dropout
  • num_classes 是最后分类的个数

前向传播的部分:

pytorch 搭建AlexNet 对花进行分类_第6张图片

  •  因为出了特征提取层后,数据的size为(128,6,6),具体来说是n*128*6*6,这里的n是batch_size ,所以我们只对后面三个维度做 flatten ,例如:

打印的网络结构为:

pytorch 搭建AlexNet 对花进行分类_第7张图片

 

3. 准备数据集

这里对五个不同属性的花做分类,都放在了flower_data 下,分别是:雏菊、蒲公英、玫瑰、向日葵、郁金香

通过split_data 文件对flower_data 划分出训练集和验证集,比例为9:1

这里的目录顺序不能错

pytorch 搭建AlexNet 对花进行分类_第8张图片

 

4. 训练网络

因为大部分代码是重合的,所以只做少量介绍,具体的可以看之前的文章:pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类

首先是定义数据预处理函数,针对训练集和验证集定义不同的预处理

这里因为样本不够,所以随机裁剪以及随机翻转做数据增强

ToTensor 是归一化和改变通道顺序

pytorch 搭建AlexNet 对花进行分类_第9张图片


然后是载入数据,训练集和验证集等等

pytorch 搭建AlexNet 对花进行分类_第10张图片


接下来显示一下图片,需要把validate_loader 里面的batch_size 改成4,然后将shuffle 改为True,否则全部都是一种类型的图片

pytorch 搭建AlexNet 对花进行分类_第11张图片

打印的label 为

 显示的图像为:

pytorch 搭建AlexNet 对花进行分类_第12张图片


 接下来实例化网络和定义优化器:

这里因为网络结构较大,训练时间长,可以设置一个准确率最好的参数(best_acc)用来实时保存最好准确率的那个权重参数


然后开始训练网络:

pytorch 搭建AlexNet 对花进行分类_第13张图片

这里的net.train() 可以管理dropout方法,相当于开启dropout 


最后就是计算准确率:

pytorch 搭建AlexNet 对花进行分类_第14张图片

net.eval() 用来关闭dropout 方法

最后保存网络的时候,我们根据best_acc 去保存最优的那个网络参数

torch.max 那里,dim = 1代表对第一个维度求取最大值,保留第零个维度,因为第零个维度是batch_size。然后后面的[1]因为torch.max会返回值、索引,这里我们只需要索引

5. 预测图片

预测的代码基本上没有变化,就是满足几个步骤即可

 1. 将下载的图片进行预处理,这里的预处理要和之前的训练的预处理一样,并且要多一个将size改变成标准的输入size

2. 增加维度,因为图片是3通道的,而我们输入多了一个batch_size 维度,所以通过unsqueeze增加一个维度

3. 加载网络参数

4. 做预测,读取最大的那个预测概率

pytorch 搭建AlexNet 对花进行分类_第15张图片

6. code

搭建AlexNet 网络结构:

import torch.nn as nn
import torch

class AlexNet(nn.Module):       # 继承nn.Module 父类
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        # 提取图像的特征
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # (input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),        # 会自动舍去小数部分,将最后一行和一列舍去,等价于左补2,右补1,上补2,下补1
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        # 对特征分类
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),          # 随机失活,对全连接层操作
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)   # batch * channel * height * width ,第一个batch不变
        x = self.classifier(x)
        return x

# net = AlexNet()
# print(net)

 训练网络部分:

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪成224 *224
                                transforms.RandomHorizontalFlip(), # 水平方向随机翻转
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}


train_dataset = datasets.ImageFolder(root="./flower_data/train",transform=data_transform["train"])   # 读取训练集
validate_dataset = datasets.ImageFolder(root='./flower_data/val',transform=data_transform["val"])     # 读取验证集

classes = ("daisy", "dandelion", "roses", "sunflowers", "tulips")   # 雏菊、蒲公英、玫瑰、向日葵、郁金香

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)     # 载入训练集
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=0)   # 载入验证集

# #显示图像 code
# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()
#
# def imshow(img):
#     img = img / 2 + 0.5  # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
#
# print(' '.join('%5s' % classes[test_label[j].item()] for j in range(4)))
# imshow(utils.make_grid(test_image))


net = AlexNet(num_classes=5)         # 实例化网络
loss_function = nn.CrossEntropyLoss()                   # 定义交叉熵损失函数
optimizer = optim.Adam(net.parameters(), lr=0.0002)     # 定义优化器

save_path = './AlexNet.pth'         #  网络保存的路径
best_acc = 0.0                      # 保存最好准确率的model

for epoch in range(10):     # 训练次数
    net.train()             # 管理dropout 方法,在训练的时候随机失活
    running_loss = 0.0
    for step, data in enumerate(train_loader,start=0):
        images, labels = data
        optimizer.zero_grad()                   # 梯度清零
        outputs = net(images)                   # 前向传播
        loss = loss_function(outputs, labels)   # 计算损失函数
        loss.backward()                         # 反向传播
        optimizer.step()                        # 更新权重

        running_loss += loss.item()

    # validate
    net.eval()      # 关闭dropout
    acc = 0.0  # accumulate accurate number / epoch
    total = 0
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images)                    # 网络预测
            predicted= torch.max(outputs, dim=1)[1]
            acc += (predicted == val_labels).sum().item()    # 计算准确率
            total += val_labels.size(0)

    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / step, 100 * acc / total))

    if (acc / total) > best_acc:                # 保存当前最好的准确率
        best_acc = acc / total
        torch.save(net.state_dict(), save_path)

print('Finished Training')

预测图片部分:

import torch
from PIL import Image
from torchvision import transforms
from model import AlexNet

data_transform = transforms.Compose([transforms.Resize((224, 224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img = Image.open("./tulips.png")            # 载入图像
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)     # 增加维度,第0维增加1 ,维度(1,C,H,W)

classes = ( "daisy","dandelion","roses","sunflowers","tulips")  # 5个分类label

model = AlexNet(num_classes=5)      # 实例化网络
model.load_state_dict(torch.load("./AlexNet.pth"))  # 读取保存的网络参数
model.eval()                        # 关闭dropout 方法

with torch.no_grad():           # 预测不需要计算梯度
    output = model(img)
    predict = torch.max(output, dim=1)[1]
    print(classes[int(predict)])



训练网络打印的信息为:

pytorch 搭建AlexNet 对花进行分类_第16张图片

输入的预测图片为:

pytorch 搭建AlexNet 对花进行分类_第17张图片

 预测的结果为:

 


如果想要分类自己的分类目标的话,只需要将flower_data 里面的图片改成自己的就行了

然后用split_data 划分一下数据就行

注: 目录就是labels ,顺序不能错。目录的顺序需要和这个保持一致

 

你可能感兴趣的:(Neural,network,pytorch,分类,深度学习,人工智能)