AlexNet网络详解(实现花的种类识别)

AlexNet

AlexNet是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Challenge)竞赛的冠军网络,分类准确率由传统的 70%+提升到 80%+。它是由Hinton和他的学生Alex Krizhevsky设计的。也是在那年之后,深度学习开始迅速发展。

AlexNet的亮点

1.AlexNet在激活函数上选取了非线性非饱和的relu函数,在训练阶段梯度衰减快慢方面,relu函数比传统神经网络所选取的非线性饱和函数(如sigmoid函数,tanh函数)要快许多。

2.AlexNet在双gpu上运行,每个gpu负责一半网络的运算。

3.采用局部响应归一化(LRN)。对于非饱和函数relu来说,不需要对其输入进行标准化,但Alex等人发现,在relu层加入LRN,可形成某种形式的横向抑制,从而提高网络的泛华能力。

4.池化方式采用overlapping pooling。即池化窗口的大小大于步长,使得每次池化都有重叠的部分。这种重叠的池化方式比传统无重叠的池化方式有着更好的效果,且可以避免过拟合现象的发生。

5.在全连接层的前两层中使用了 Dropout 随机失活神经元操作,以减少过拟合。

Alexnet网络的结构

第一层卷积层

AlexNet网络详解(实现花的种类识别)_第1张图片

输入图片大小3x224x224,卷积核大小为11,padding[1,2],上面左面补一行0,下面右面补两行0,stride为4,96个卷积核,因为在两个gpu上进行卷积运算,两个卷积运算卷积核数量为48。输出大小为96x55x55,也就是两个48x55x55。

第二层池化层

AlexNet网络详解(实现花的种类识别)_第2张图片

输入为96x55x55,padding为0,按stride为2进行3 × 3的Max池化,输出为96x27x27。

第三层卷积层

AlexNet网络详解(实现花的种类识别)_第3张图片

输入为96x27x27,卷积核大小为5,padding[2,2],上下左右各补两行0,stride为1,512个卷积核,因为在两个gpu上进行卷积运算,两个卷积运算卷积核数量为128。输出大小为256x27x27,也就是两个128x27x27。

第四层池化层

AlexNet网络详解(实现花的种类识别)_第4张图片

输入为256x27x27,padding为0,按stride为2进行3 × 3的Max池化,输出为256x13x13。

第五层卷积层

AlexNet网络详解(实现花的种类识别)_第5张图片

输入为256x13x13,卷积核大小为3,padding[1,1],上下左右各补一行0,stride为1,384个卷积核,因为在两个gpu上进行卷积运算,两个卷积运算卷积核数量为192。输出大小为384x13x13,也就是两个192x13x13。

第六层卷积层

AlexNet网络详解(实现花的种类识别)_第6张图片

输入为384x13x13,卷积核大小为3,padding[1,1],上下左右各补一行0,stride为1,384个卷积核,因为在两个gpu上进行卷积运算,两个卷积运算卷积核数量为192。输出大小为384x13x13,也就是两个192x13x13。

第七层卷积层

AlexNet网络详解(实现花的种类识别)_第7张图片

输入为384x13x13,卷积核大小为3,padding[1,1],上下左右各补一行0,stride为1,256个卷积核,因为在两个gpu上进行卷积运算,两个卷积运算卷积核数量为128。输出大小为256x13x13,也就是两个128x13x13。

第八层池化层

AlexNet网络详解(实现花的种类识别)_第8张图片

输入为256x13x13,padding为0,按stride为2进行3 × 3的Max池化,输出为256x6x6。

最后三层全连接层

最后三层全连接层输出最后结果。

AlexNet学习实现花的种类识别

1.建立模型

这里我们直接取一半进行卷积运算,第一层卷积层卷积核数量为48,最后的输出数量为我们要辨别的花种类的数量。

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

2.训练模型

数据集我上传到百度网盘里,可自行下载解压到根目录下。

链接:https://pan.baidu.com/s/1F7qqVMlx6bGGugjNH2G9RA?pwd=i3p0 
提取码:i3p0 

AlexNet网络详解(实现花的种类识别)_第9张图片

我们可以选择gpu训练,如果没有空闲的gpu选择cpu。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

图像预处理加上随机截取和翻转,毕竟一张图片是蒲公英,总不能随机截取一下或者反转一下就不是蒲公英了吧,等于扩大了训练集数量。

"train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

获取花分类名称对应索引,遍历字典将val-key -> key-val,并转成json格式写入文件。

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx  # 获取花分类名称对应索引
cla_dict = dict((val, key) for key, val in flower_list.items())  # 遍历字典将val-key -> key-val

# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

得到:

AlexNet网络详解(实现花的种类识别)_第10张图片

完整训练代码如下:

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    train_dataset = datasets.ImageFolder(root='./train',
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx  # 获取花分类名称对应索引
    cla_dict = dict((val, key) for key, val in flower_list.items())  # 遍历字典将val-key -> key-val

    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)

    validate_dataset = datasets.ImageFolder(root='./val',
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=0)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()  # dropout开启
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()  # dropout关闭
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

AlexNet网络详解(实现花的种类识别)_第11张图片

可以看出,最高识别率达到72%。

3.测试泛化能力

去网上找了几张图片,简直恐怖,我都不认识郁金香。

AlexNet网络详解(实现花的种类识别)_第12张图片

AlexNet网络详解(实现花的种类识别)_第13张图片

AlexNet网络详解(实现花的种类识别)_第14张图片

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "3.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)

    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = AlexNet(num_classes=5).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

你可能感兴趣的:(深度学习,学习,深度学习,人工智能)