The Four Main Tasks of Computer Vision — Image Classification, Part 1: AlexNet (with paper and source code)

Since I personally started learning image classification with LeNet and AlexNet, and LeNet itself is fairly simple, this series begins with AlexNet. It will work its way from AlexNet up to more recent classification algorithms such as ConvNeXt and Transformer-based models, and I will keep updating it as new classification algorithms are published and as my own studies progress.

Code repository: xs-dl — research and implementation of deep-learning algorithms (gitee.com) (continuously updated)

1 AlexNet network architecture

Paper: ImageNet Classification with Deep Convolutional Neural Networks (neurips.cc)

[Figure 1: AlexNet network architecture, from the original paper]

The figure above comes from the paper ImageNet Classification with Deep Convolutional Neural Networks, published by Alex Krizhevsky et al. in 2012. AlexNet decisively outperformed the traditional algorithms of the time and won the 2012 ImageNet image-classification competition. The network has 60 million parameters and 650,000 neurons; it consists of five convolutional layers, some of which are followed by max-pooling layers, then three fully connected layers and a final 1000-way softmax output.

2 AlexNet network parameters

| layer_name | kernel_size | kernel_num | stride | padding | input_size | output_size |
|------------|-------------|------------|--------|---------|------------|-------------|
| conv1      | 11*11       | 96         | 4      | [1, 2]  | 224*224*3  | 55*55*96    |
| max_pool1  | 3*3         | /          | 2      | /       | 55*55*96   | 27*27*96    |
| conv2      | 5*5         | 256        | 1      | [2, 2]  | 27*27*96   | 27*27*256   |
| max_pool2  | 3*3         | /          | 2      | /       | 27*27*256  | 13*13*256   |
| conv3      | 3*3         | 384        | 1      | [1, 1]  | 13*13*256  | 13*13*384   |
| conv4      | 3*3         | 384        | 1      | [1, 1]  | 13*13*384  | 13*13*384   |
| conv5      | 3*3         | 256        | 1      | [1, 1]  | 13*13*384  | 13*13*256   |
| max_pool3  | 3*3         | /          | 2      | /       | 13*13*256  | 6*6*256     |
| fc1        | /           | 4096       | /      | /       | 6*6*256    | 4096        |
| fc2        | /           | 4096       | /      | /       | 4096       | 4096        |
| fc3        | /           | 1000       | /      | /       | 4096       | 1000        |

The parameters in the table above are inferred from the data in the original paper and the reference source code; most of them can be read directly from the network diagram in the paper. The padding values follow from the output-size formula:

N = ((W - F + 2P) / S) + 1

N: output feature-map size

W: input feature-map size

F: kernel (filter) size

P: number of padding pixels

S: stride of the convolution kernel

Taking conv1 as an example (P is the unknown parameter):

From N = ((W - F + 2P) / S) + 1 we get 55 = ((224 - 11 + 2P) / 4) + 1,

which gives 2P = 3, i.e. 3 padded pixels in total, corresponding to the asymmetric padding [1, 2] (1 pixel on one side, 2 on the other); a quick numeric check is sketched below.
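
A minimal sketch to check these sizes numerically (a helper of my own, not from the original post; it uses the floor division that PyTorch applies when the sizes do not divide evenly):

def conv_out(w, f, p_total, s):
    """Output size for input size w, kernel f, total padding p_total (left + right) and stride s."""
    return (w - f + p_total) // s + 1

print(conv_out(224, 11, 1 + 2, 4))  # conv1 with padding [1, 2] -> 55
print(conv_out(224, 11, 2 + 2, 4))  # symmetric padding=2, as used in model.py below -> also 55
print(conv_out(55, 3, 0, 2))        # max_pool1 -> 27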

3 PyTorch implementation

(1) Data preparation: make_data.py (flower classification dataset)

import os

from PIL import Image

# URL of the dataset to download
DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

# path where the dataset archive was extracted (adjust to your own setup)
flower_photos = "D:\\train_data\\flower_photos\\"
# root path for the generated training data (adjust to your own setup)
base_url = "D:\\train_data\\"

for item in os.listdir(flower_photos):
    path_temp = flower_photos + item
    n = 0
    for name in os.listdir(path_temp):
        n += 1
        img = Image.open(path_temp + "\\" + name)
        # convert to 3-channel RGB (some images are grayscale or RGBA)
        img = img.convert("RGB")
        # validation split: every 8th image goes to the validation set (~12.5% val; adjust the modulus to change the train/val ratio)
        if n % 8 == 0:
            if not os.path.exists(base_url + "val\\" + item):
                os.makedirs(base_url + "val\\" + item)
            img.save(base_url + "val\\" + item + "\\" + name)
        else:
            if not os.path.exists(base_url + "train\\" + item):
                os.makedirs(base_url + "train\\" + item)
            img.save(base_url + "train\\" + item + "\\" + name)
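
If the dataset has not been downloaded and extracted yet, a minimal sketch using only the Python standard library could look like the following (the paths are the same assumed D:\train_data locations as in make_data.py):

import os
import tarfile
import urllib.request

DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
base_url = "D:\\train_data\\"

os.makedirs(base_url, exist_ok=True)
archive_path = base_url + "flower_photos.tgz"
# download the archive and extract it, producing D:\train_data\flower_photos\
urllib.request.urlretrieve(DATA_URL, archive_path)
with tarfile.open(archive_path, "r:gz") as tar:
    tar.extractall(base_url)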

(2) AlexNet network: model.py

import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes=5):
        super(AlexNet, self).__init__()
        # Note: the channel counts here are half of those in the paper (48/128/192/192/128
        # instead of 96/256/384/384/256), i.e. a single one of the two GPU branches of the original network.
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),   # input[3, 224, 224] -> [48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # -> [48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),             # -> [128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                    # -> [128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),            # -> [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),            # -> [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),            # -> [128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                    # -> [128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x
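
To double-check that the flattened feature size really is 128 * 6 * 6, a quick sanity check with a dummy input (not part of the original scripts) can be run:

import torch
from model import AlexNet

net = AlexNet(num_classes=5)
x = torch.randn(1, 3, 224, 224)   # one dummy 224x224 RGB image
print(net.features(x).shape)      # torch.Size([1, 128, 6, 6])
print(net(x).shape)               # torch.Size([1, 5])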

(3) Training script

import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from model import AlexNet
from torchvision import transforms, datasets


def main():
    # use the GPU if one is available, otherwise fall back to the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # transforms for the training and validation sets
    train_form = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    val_form = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # training set
    image_path = "D:\\train_data"
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=train_form)

    # num_workers=0 loads data in the main process; increase it if your machine has spare CPU resources
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)

    # validation set
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=val_form)
    val_num = len(validate_dataset)

    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4, shuffle=False, num_workers=0)

    net = AlexNet(num_classes=5)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()

    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # training phase
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validation phase
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' % (epoch + 1, running_loss / train_steps, val_accurate))

        # save the best model so far (with enough disk space you could also save a checkpoint every epoch)
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()
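
Note that the test script in the next step reads a class_indices.json file mapping prediction indices back to class names, which the training script above never writes. A minimal sketch of my own to generate it from the ImageFolder class mapping, using the same assumed data path:

import json
import os
from torchvision import datasets

# rebuild the ImageFolder only to read its {class_name: index} mapping, then invert it
train_dataset = datasets.ImageFolder(root=os.path.join("D:\\train_data", "train"))
idx_to_class = {str(v): k for k, v in train_dataset.class_to_idx.items()}
with open("./class_indices.json", "w") as f:
    json.dump(idx_to_class, f, indent=4)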


(4) Test script

import os
import json
import torch

from PIL import Image
from model import AlexNet
from torchvision import transforms


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    img = Image.open("./1.jpg")

    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = AlexNet(num_classes=5).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    print(print_res)


if __name__ == '__main__':
    main()


4 Summary of AlexNet

AlexNet was able to outperform traditional image-classification algorithms mainly for the following reasons:

(1) Use of the non-linear ReLU activation function

(2) Dropout to reduce overfitting

(3) Data augmentation, which effectively enlarges the training set and likewise helps against overfitting

(4) Parallel training on multiple GPUs to speed up training

(5) Use of LRN (local response normalization) layers (see the note below)
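
The model.py above (like torchvision's AlexNet) actually omits the LRN layers. If you want to reproduce them, PyTorch provides nn.LocalResponseNorm; a sketch of the first block with the hyperparameters from the paper (n=5, k=2, alpha=1e-4, beta=0.75):

import torch.nn as nn

# first feature block with LRN added after the ReLU, as in the original paper
features_head = nn.Sequential(
    nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
    nn.ReLU(inplace=True),
    nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2.0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # ... the remaining layers stay the same as in model.py
)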
