[Paper Reading] AlexNet: ImageNet Classification with Deep Convolutional Neural Networks

Contents

      • Paper Reading
      • Model Diagram
      • Code Implementation
        • train.sh
        • predict.sh
      • Experimental Results

If anything here is wrong, corrections are welcome. Let's learn from each other.

Paper Reading

ImageNet Classification with Deep Convolutional Neural Networks
The paper was published in 2012. Its main novelties are: a deep CNN, LRN (later shown to be of little use), overlapping pooling, ReLU, and Dropout.

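1. ReLU nonlinearity. Non-saturating neurons train faster than saturating ones: in the paper's experiment, a four-layer CNN with ReLUs reached 25% training error on CIFAR-10 six times faster than the same network with tanh neurons.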

Non-saturating neurons: the activation function does not squash the neuron's output into a bounded range (ReLU is unbounded for positive inputs).

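Saturating neurons: the output is squashed into a fixed interval (for example, sigmoid into (0, 1) and tanh into (-1, 1)), so the gradient becomes tiny for inputs of large magnitude.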

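In nn.ReLU(inplace=True), inplace controls whether the input tensor is overwritten with the result instead of allocating a new tensor, which reduces memory usage.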
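To make "saturating" concrete, here is a small pure-Python sketch (no PyTorch needed; the helper names are my own) that compares the derivatives of sigmoid, tanh, and ReLU at increasingly large inputs. The saturating functions' gradients shrink toward zero, while ReLU's stays at 1 for any positive input, which is one intuition for the faster training reported in the paper.

```python
import math


def sigmoid_grad(x: float) -> float:
    """Derivative of the logistic sigmoid: s(x) * (1 - s(x))."""
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 - s)


def tanh_grad(x: float) -> float:
    """Derivative of tanh: 1 - tanh(x)^2."""
    return 1.0 - math.tanh(x) ** 2


def relu_grad(x: float) -> float:
    """(Sub)gradient of ReLU; 0 is used at x == 0 by convention."""
    return 1.0 if x > 0 else 0.0


for x in (0.5, 2.0, 5.0, 10.0):
    print(f"x={x:5.1f}  sigmoid'={sigmoid_grad(x):.2e}  "
          f"tanh'={tanh_grad(x):.2e}  relu'={relu_grad(x):.1f}")
```

As x grows, the sigmoid and tanh columns collapse toward 0 while the ReLU column stays at 1.0, so gradients flowing back through ReLU units are not attenuated.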

2. Training on multiple GPUs (cross-GPU parallelization)
I did not read this section closely.
3. LRN (Local Response Normalization)
[Figure 1: the local response normalization formula from the paper]
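The effect is that, at the same position of a feature map, a strongly activated channel suppresses the responses of its neighboring channels. This is called lateral inhibition (similar to apical dominance in plants).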
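$a^i_{x,y}$ is the activity of channel $i$ at position $(x, y)$; $k$ keeps the denominator from being zero; $n$ is the number of adjacent channels summed over; $\alpha$ and $\beta$ are constants. AlexNet uses $k = 2$, $n = 5$, $\alpha = 10^{-4}$, $\beta = 0.75$.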
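For reference, the normalization in the paper is

$$b^i_{x,y} = a^i_{x,y} \Big/ \left(k + \alpha \sum_{j=\max(0,\, i-n/2)}^{\min(N-1,\, i+n/2)} \left(a^j_{x,y}\right)^2\right)^{\beta}$$

where $N$ is the total number of channels. Below is a minimal pure-Python sketch of this formula at a single $(x, y)$ position. The function name `local_response_norm` is hypothetical, not a library API; it only illustrates how one large activation suppresses its neighbors.

```python
from typing import List


def local_response_norm(a: List[float], k: float = 2.0, n: int = 5,
                        alpha: float = 1e-4, beta: float = 0.75) -> List[float]:
    """LRN across channels at one (x, y) position; a[i] is channel i's activity."""
    num_channels = len(a)
    half = n // 2  # n / 2 in the paper's formula (integer division for n = 5)
    out = []
    for i in range(num_channels):
        lo = max(0, i - half)
        hi = min(num_channels - 1, i + half)
        sq_sum = sum(a[j] ** 2 for j in range(lo, hi + 1))
        out.append(a[i] / (k + alpha * sq_sum) ** beta)
    return out


# Channel 1 is strongly activated; channels within 2 of it get divided by a larger term.
acts = [1.0, 50.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
print([round(v, 4) for v in local_response_norm(acts)])
```

If you use PyTorch's `nn.LocalResponseNorm(size, alpha, beta, k)` instead, note that (as far as I can tell from its documentation) it divides `alpha` by `size`, so matching the paper's constants would need `alpha=5e-4` with `size=5`.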
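4. Overlapping pooling
Neighboring pooling windows overlap (window size z = 3 with stride s = 2, so s < z). This was meant to reduce overfitting, but it was later also found to be of limited use.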
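A pooling layer overlaps when its stride s is smaller than its window size z. The tiny sketch below (pure Python; `pool_windows` is a hypothetical helper) lists which positions each window covers along one axis. In the paper, the overlapping setting z = 3, s = 2 is compared against the non-overlapping z = 2, s = 2 and is reported to lower top-1 and top-5 error by 0.4% and 0.3%.

```python
from typing import List, Tuple


def pool_windows(length: int, z: int, s: int) -> List[Tuple[int, int]]:
    """[start, end) ranges covered by a 1-D pooling window of size z and stride s
    (no padding; windows that would run past the end are dropped)."""
    return [(start, start + z) for start in range(0, length - z + 1, s)]


print(pool_windows(7, z=3, s=2))   # overlapping: neighboring windows share one element
print(pool_windows(7, z=2, s=2))   # non-overlapping
print(len(pool_windows(55, z=3, s=2)))  # AlexNet's first pooling: 55 -> 27
```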
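5. Data augmentation
The first scheme is label-preserving horizontal flips (mirroring) plus random crops: from a 256×256 input, taking 224×224 crops and their mirror images yields 2048 variants (32 × 32 × 2). The second scheme alters the intensities of the RGB channels using PCA; I haven't learned this yet and don't fully understand it, so I will add it later.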
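A minimal pure-Python sketch of the crop-and-flip scheme, assuming a toy single-channel image stored as a list of pixel rows (the `Image` alias and both function names are hypothetical). Strictly, a 224 crop of a 256 image has 33 possible offsets per axis; the paper's factor of 2048 corresponds to (256 − 224)² × 2.

```python
import random
from typing import List

# Hypothetical representation: a single-channel image as a list of pixel rows.
Image = List[List[int]]


def random_crop_and_flip(img: Image, size: int, rng: random.Random) -> Image:
    """Take a random size x size crop, then mirror it horizontally with probability 0.5."""
    h, w = len(img), len(img[0])
    top = rng.randrange(0, h - size + 1)   # h - size + 1 possible vertical offsets
    left = rng.randrange(0, w - size + 1)  # w - size + 1 possible horizontal offsets
    crop = [row[left:left + size] for row in img[top:top + size]]
    if rng.random() < 0.5:
        crop = [row[::-1] for row in crop]  # label-preserving horizontal flip
    return crop


def paper_variant_count(full: int, crop: int) -> int:
    """The paper's counting: (full - crop)^2 crop positions, times 2 for mirroring."""
    return (full - crop) ** 2 * 2


rng = random.Random(0)
toy = [[r * 8 + c for c in range(8)] for r in range(8)]  # 8x8 toy image
patch = random_crop_and_flip(toy, 5, rng)
print(len(patch), len(patch[0]))
print(paper_variant_count(256, 224))
```

train.py uses `transforms.RandomResizedCrop(224)` plus `transforms.RandomHorizontalFlip()`, which also randomizes the crop's scale and aspect ratio before resizing, so it is a stronger variant of this scheme.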
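6. Dropout

It breaks co-adaptation: because any neuron may be dropped, a neuron cannot rely on particular other neurons being present and must learn features that are useful together with many different random subsets of the neurons in the previous layer.

In other words, part of the network's "memory" is randomly dropped on each training step.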
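The paper drops each hidden unit with probability 0.5 during training and multiplies the outputs by 0.5 at test time. PyTorch's `nn.Dropout` uses the equivalent "inverted" form instead: it scales surviving activations by 1/(1 − p) during training and does nothing at eval time. A minimal pure-Python sketch of the inverted form (the `dropout` function here is my own, not the PyTorch one):

```python
import random
from typing import List


def dropout(xs: List[float], p: float, training: bool, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each element with probability p during training and
    scale the survivors by 1 / (1 - p); return the input unchanged at eval time."""
    if not training or p == 0.0:
        return list(xs)
    keep = 1.0 - p
    return [x / keep if rng.random() >= p else 0.0 for x in xs]


rng = random.Random(42)
train_out = dropout([1.0] * 10, p=0.5, training=True, rng=rng)
eval_out = dropout([1.0] * 10, p=0.5, training=False, rng=rng)
print(train_out)  # a random mix of 0.0 and 2.0
print(eval_out)
```

The scaling keeps each output's expected value equal to its input, (1 − p) · x / (1 − p) = x, which is why no rescaling is needed after `net.eval()`.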

7. SGD with momentum
[Figure 2: the SGD-with-momentum update rule from the paper]
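Momentum accumulates past gradients, which speeds up descent, damps oscillations, and can help carry the parameters past shallow local minima during backpropagation. Weight decay acts like L2 regularization. The paper uses a batch size of 128, momentum 0.9, and weight decay 0.0005.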
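A minimal sketch of the paper's update rule, $v_{i+1} = 0.9\,v_i - 0.0005\,\epsilon\,w_i - \epsilon \frac{\partial L}{\partial w}\big|_{w_i}$ and $w_{i+1} = w_i + v_{i+1}$, applied to a toy one-parameter problem (the function name and the toy loss are my own choices for illustration):

```python
def sgd_momentum_step(w: float, v: float, grad: float, lr: float = 0.01,
                      momentum: float = 0.9, weight_decay: float = 5e-4):
    """One update in the paper's form:
    v <- momentum * v - weight_decay * lr * w - lr * grad
    w <- w + v
    """
    v_next = momentum * v - weight_decay * lr * w - lr * grad
    return w + v_next, v_next


# Toy problem: minimize L(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w, v = 0.0, 0.0
for _ in range(500):
    w, v = sgd_momentum_step(w, v, grad=2.0 * (w - 3.0))

# Weight decay pulls the minimum slightly toward 0: 6 / (2 + 5e-4) instead of 3.
print(w)
```

In PyTorch the same setting is usually written as `optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)`. Its internal formulation (buffer b ← μ·b + g + λ·w, then w ← w − lr·b) matches the rule above when the learning rate is constant. Note that train.py below uses Adam instead.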

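Model structure in the paper (note that the AlexNet implemented in this post differs from the paper's):
The computation is split across two GPUs. The GPUs exchange data only at the third convolutional layer and at the fully connected layers; the other layers run independently on each GPU.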
[Figure 3: the AlexNet architecture from the paper]

Model Diagram

[Figure 4: model diagram]

Code Implementation

model.py

from collections import OrderedDict
from typing import Optional
import torch
import torch.nn as nn


class BasicConv(nn.Sequential):
    """Conv2d followed by an activation module (ReLU by default)."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: int = 2,
                 activation_layer: Optional[nn.Module] = None):
        if activation_layer is None:
            activation_layer = nn.ReLU(inplace=True)
        super(BasicConv, self).__init__(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding),
            activation_layer
        )


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()

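        layers = OrderedDict()
        # ReLU and MaxPool2d have no parameters, so single instances are reused below.
        # Channel counts are half of the paper's (96, 256, 384, 384, 256).
        activation_layer = nn.ReLU(inplace=True)
        Maxpool_layer = nn.MaxPool2d(kernel_size=3, stride=2)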
        layers.update({'conv1': BasicConv(3, 48, 11, 4, 2, activation_layer)})
        layers.update({'max1': Maxpool_layer})
        layers.update({'conv2': BasicConv(48, 128, 5, 1, 2, activation_layer)})
        layers.update({'max2': Maxpool_layer})
        layers.update(
            {'conv3': BasicConv(128, 192, 3, 1, 1, activation_layer)})
        layers.update(
            {'conv4': BasicConv(192, 192, 3, 1, 1, activation_layer)})
        layers.update(
            {'conv5': BasicConv(192, 128, 3, 1, 1, activation_layer)})
        layers.update({'max3': Maxpool_layer})

        self.features = nn.Sequential(layers)  # 3x224x224 -> 128x6x6

        # Fully connected head; the paper uses 4096 units per hidden layer.
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            activation_layer,
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            activation_layer,
            nn.Linear(2048, num_classes),
        )

        self.apply(_init_weights)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x


def _init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(
            m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
            # nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.zeros_(m.bias)
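As a sanity check on `nn.Linear(128 * 6 * 6, 2048)`, the sketch below (pure Python; `out_size` and `LAYERS` are my own names) applies the standard output-size formula floor((W − K + 2P) / S) + 1, which is what `nn.Conv2d` and `nn.MaxPool2d` compute with their default dilation and `ceil_mode=False`, to the layer settings in model.py for a 224×224 input. It also confirms that the conv widths are exactly half of the paper's, i.e. one GPU's share per layer.

```python
from typing import List, Optional, Tuple


def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv/pool layer: floor((W - K + 2P) / S) + 1."""
    return (size - kernel + 2 * padding) // stride + 1


# (name, kernel, stride, padding, out_channels); None = pooling keeps the channel count.
LAYERS: List[Tuple[str, int, int, int, Optional[int]]] = [
    ("conv1", 11, 4, 2, 48),
    ("max1", 3, 2, 0, None),
    ("conv2", 5, 1, 2, 128),
    ("max2", 3, 2, 0, None),
    ("conv3", 3, 1, 1, 192),
    ("conv4", 3, 1, 1, 192),
    ("conv5", 3, 1, 1, 128),
    ("max3", 3, 2, 0, None),
]

size, channels = 224, 3
for name, k, s, p, out_ch in LAYERS:
    size = out_size(size, k, s, p)
    if out_ch is not None:
        channels = out_ch
    print(f"{name:<6} -> {channels} x {size} x {size}")

flat_features = channels * size * size  # input width of the first nn.Linear
print("flattened:", flat_features)

# The code's conv widths are half of the paper's (one GPU's share per layer).
paper_conv_channels = [96, 256, 384, 384, 256]
code_conv_channels = [c for (_, _, _, _, c) in LAYERS if c is not None]
print([c // 2 for c in paper_conv_channels] == code_conv_channels)
```

The trace goes 224 → 55 → 27 → 27 → 13 → 13 → 13 → 13 → 6, so the flattened feature vector has 128 × 6 × 6 entries.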

train.py

import argparse
import os
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm


from model import AlexNet as create_model


# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}

def main(args):     
    print(args)

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")    
    print("using {} device.".format(device))

    batch_size=args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}


    if args.num_classes==10:

        train_dataset = datasets.CIFAR10(root=args.data_path, train=True,
                                      download=True, transform=data_transform['train'])
        val_dataset = datasets.CIFAR10(root=args.data_path, train=False,
                                    download=True, transform=data_transform['val'])
    else:
        train_dataset = datasets.CIFAR100(root=args.data_path, train=True,
                                      download=True, transform=data_transform['train'])
        val_dataset = datasets.CIFAR100(root=args.data_path, train=False,
                                    download=True, transform=data_transform['val'])
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw)
    
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=nw)

    class_to_idx = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in class_to_idx.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)


    val_num = len(val_dataset)
    train_num = len(train_dataset)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = create_model(num_classes=args.num_classes)
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
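    # The paper trains with SGD (momentum 0.9, weight decay 5e-4); Adam is used here instead.
    # optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    optimizer = optim.Adam(net.parameters(), lr=args.lr)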

    epochs = args.epochs
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for data in train_bar:
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss.item())

        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), args.model_path)

    print('Finished Training')


if __name__ == '__main__':
    parser=argparse.ArgumentParser()

    parser.add_argument('--num_classes',type=int,default=100)
    parser.add_argument('--epochs',type=int,default=20)
    parser.add_argument('--batch_size',type=int,default=48)
    parser.add_argument('--lr',type=float,default=0.001)

    parser.add_argument('--data_path', type=str,
                        default="D:/dataset/cifar100")
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--model_path',default='./AlexNet_cifar100.pth')
    opt = parser.parse_args()

    main(opt)

predict.py

import argparse
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet


def main(args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    img_path = args.data_path
    img = Image.open(img_path).convert('RGB')

    plt.imshow(img)
    img = data_transform(img)

    # add a batch dimension: [C, H, W] -> [1, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    json_path = args.json_path
    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    model = AlexNet(num_classes=args.num_classes).to(device)
    weights_path = args.model_path
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--data_path', type=str,
                        default="D:/dataset/apple.jpg")
    parser.add_argument('--json_path', default='./class_indices.json')
    parser.add_argument('--model_path', default='./AlexNet_cifar10.pth')
    opt = parser.parse_args()

    main(opt)

train.sh

[ -z "${lr}" ] && lr="1e-3"
[ -z "${epochs}" ] && epochs="8"
[ -z "${batch_size}" ] && batch_size="64"
[ -z "${num_classes}" ] && num_classes="10"

[ -z "${model_path}" ] && model_path="./AlexNet_cifar10.pth"
[ -z "${data_path}" ] && data_path="D:/dataset/cifar10"

echo -e "\n\n"
echo "=====================================ARGS======================================"
echo "arg: $0"
echo "lr: ${lr}"
echo "epoch: ${epochs}"
echo "batch_size: ${batch_size}"
echo "num_classes: ${num_classes}"

echo "model_path: ${model_path}"
echo "data_path: ${data_path}"
echo "==============================================================================="


python ./train.py --batch_size "$batch_size" --epochs "$epochs" --lr "$lr" --num_classes "$num_classes" \
     --model_path "$model_path" --data_path "$data_path"



sleep 2m

predict.sh

[ -z "${json_path}" ] && json_path="./class_indices.json"
[ -z "${model_path}" ] && model_path="./AlexNet_cifar100.pth"
[ -z "${data_path}" ] && data_path="D:/dataset/apple.jpg"
[ -z "${num_classes}" ] && num_classes="100"

echo -e "\n\n"
echo "=====================================ARGS======================================"
echo "json_path: ${json_path}"

echo "num_classes: ${num_classes}"
echo "model_path: ${model_path}"
echo "data_path: ${data_path}"
echo "==============================================================================="


python ./predict.py --json_path "$json_path" --model_path "$model_path" --data_path "$data_path" --num_classes "$num_classes"



sleep 2m

Experimental Results

The dataset used is instructor P's five-class flower dataset (daisy, dandelion, roses, sunflowers, tulips).
[Figures 5 and 6: experimental results]
