在2010年的ImageNet LSVRC-2010上,AlexNet在给包含有1000种类别的共120万张高分辨率图片的分类任务中,在测试集上的top-1和top-5错误率为37.5%和17.0%(top-5 错误率:即对一张图像预测5个类别,只要有一个和人工标注类别相同就算对,否则算错。同理top-1对一张图像只预测1个类别),在ImageNet LSVRC-2012的比赛中,取得了top-5错误率为15.3%的成绩,而第二名的成绩为26.2%,可见AlexNet在当时有多强大。
input 224*224*3
Conv(kernel_size=11*11, kernel_num=96, stride=4, padding=2)
output (224-11+2*2)/4+1=55 -> 55*55*96
Relu
Pool(kernel_size=3*3, stride=2)
output (55-3)/2+1=27 -> 27*27*96
Local Response Normalization(local_size=5)
output 27*27*96
input 27*27*96
Conv(kernel_size=5*5, kernel_num=256, stride=1, padding=2)
output (27-5+2*2)/1+1=27 -> 27*27*256
Relu
Pool(kernel_size=3*3, stride=2)
output (27-3)/2+1=13 -> 13*13*256
Local Response Normalization(local_size=5)
output 13*13*256
input 13*13*256
Conv(kernel_size=3*3, kernel_num=384, stride=1, padding=1)
output (13-3+2*1)/1+1=13 -> 13*13*384
Relu
output 13*13*384
input 13*13*384
Conv(kernel_size=3*3, kernel_num=384, stride=1, padding=1)
output (13-3+2*1)/1+1=13 -> 13*13*384
Relu
output 13*13*384
input 13*13*384
Conv(kernel_size=3*3, kernel_num=256, stride=1, padding=1)
output (13-3+2*1)/1+1=13 -> 13*13*256
Relu
Pool(kernel_size=3*3, stride=2)
output (13-3)/2+1=6 -> 6*6*256
input 6*6*256
Fc
Relu
Dropout
output 4096
input 4096
Fc
Relu
Dropout
output 4096
input 4096
Fc
Softmax
output 1000
Alexnet网络中各个层发挥的作用如下表所述:
model.py
import torch.nn as nn
import torch
class AlexNet(nn.Module):
    """AlexNet-style image classifier using half the kernel counts of the paper.

    The network is a convolutional feature extractor (``self.features``)
    followed by a three-layer fully connected head (``self.classifier``).
    It expects input of shape ``[N, 3, 224, 224]`` and returns raw class
    scores of shape ``[N, num_classes]`` (no softmax is applied).

    Args:
        num_classes: Number of output classes of the final linear layer.
        init_weights: When true, re-initialise conv and linear layers via
            ``_initialize_weights`` instead of keeping PyTorch's defaults.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()

        # Feature extractor. Shapes in the comments are [C, H, W] for a
        # 224x224 RGB input; non-integer output sizes are floored.
        conv_layers = [
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # [3, 224, 224] -> [48, 55, 55]
            nn.ReLU(inplace=True),                                  # in-place to save memory
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> [48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # -> [128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> [128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # -> [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # -> [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # -> [128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> [128, 6, 6]
        ]
        self.features = nn.Sequential(*conv_layers)

        # Fully connected head; dropout precedes the first two linear layers.
        head_layers = [
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class scores for a batch of images ``x``."""
        feature_maps = self.features(x)
        # Keep the batch dimension and collapse [C, H, W] into one vector.
        flattened = feature_maps.flatten(start_dim=1)
        return self.classifier(flattened)

    def _initialize_weights(self):
        """Apply Kaiming-normal init to convs and N(0, 0.01) to linears.

        All biases that exist are set to zero.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
                continue
            if not isinstance(layer, nn.Conv2d):
                continue
            # He (Kaiming) initialisation suited to ReLU activations.
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
从http://download.tensorflow.org/example_images/flower_photos.tgz下载数据集
执行下面代码,将数据集划分为训练集与验证集。
split_data.py
import os
from shutil import copy
import random
def mkfile(file):
    """Create the directory *file*, including missing parents, if needed.

    Calling this on a directory that already exists is a no-op.

    Args:
        file: Path of the directory to create.

    Raises:
        FileExistsError: If *file* exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(file, exist_ok=True)
# Source directory: one sub-folder of images per flower class.
file = 'flower_data/flower_photos'
# Every entry except those containing ".txt" is treated as a class folder.
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]

# Create the per-class output folders for the training split ...
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/'+cla)

# ... and for the validation split.
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/'+cla)

# Fraction of each class that goes to the validation split.
split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    # Randomly chosen validation images; a set keeps the membership test
    # below O(1) instead of scanning a list for every image.
    eval_index = set(random.sample(images, k=int(num*split_rate)))
    for index, image in enumerate(images):
        subset = 'val' if image in eval_index else 'train'
        copy(cla_path + image, 'flower_data/' + subset + '/' + cla)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
    print()

print("processing done!")
训练模型 train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time
# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Preprocessing pipelines. Training uses a random crop and horizontal flip
# for augmentation; validation only resizes. Both normalise each channel
# with mean 0.5 and std 0.5.
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),  # must be the tuple (224, 224), not a bare 224
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

# data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
data_root = os.getcwd()
image_path = os.path.join(data_root, "flower_data")  # flower data set path

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# class_to_idx maps class name -> index, e.g.
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
# Invert it to index -> class name for use at prediction time.
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0)

# Pull one validation batch for inspection. The built-in next() is used
# because DataLoader iterators no longer provide a .next() method.
test_data_iter = iter(validate_loader)
test_image, test_label = next(test_data_iter)
# print(test_image[0].size(),type(test_image[0]))
# print(test_label[0],test_label[0].item(),type(test_label[0]))

# Show the sampled images; set batch_size of validate_loader to 4 first.
# def imshow(img):
#     img = img / 2 + 0.5  # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
#
# print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
# imshow(utils.make_grid(test_image))
# Build the 5-class model with custom weight initialisation and move it to
# the selected device.
net = AlexNet(num_classes=5, init_weights=True)
net.to(device)
# Loss function: cross entropy.
loss_function = nn.CrossEntropyLoss()
# Optimizer: Adam.
optimizer = optim.Adam(net.parameters(), lr=0.0002)
# Where the best weights are saved.
save_path = './AlexNet.pth'
# Best validation accuracy seen so far.
best_acc = 0.0

# Alternate one training pass and one validation pass per epoch.
for epoch in range(10):
    # train
    net.train()  # enable dropout during training
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # Accumulate the batch loss as a plain float.
        batch_loss = loss.item()
        running_loss += batch_loss

        # Progress bar: percentage done, bar, and the current batch loss.
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, batch_loss), end="")
    print()
    print(time.perf_counter()-t1)

    # validate
    net.eval()  # disable dropout so all neurons are used
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
    val_accurate = acc / val_num
    # Keep only the weights of the best-performing epoch.
    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net.state_dict(), save_path)
    # Average over the number of batches; `step` is the last index (n - 1),
    # so dividing by it would be off by one and fail with a single batch.
    print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
          (epoch + 1, running_loss / len(train_loader), val_accurate))

print('Finished Training')
使用向日葵图片进行测试 predict.py
import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# error seen with some torch/matplotlib installs; confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Same preprocessing as the validation pipeline used in training.
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image; force 3-channel RGB so grayscale/RGBA files match the
# 3-channel Normalize above and the network's first convolution.
img = Image.open("./sunflower.jpg").convert("RGB")  # check the sunflower
# img = Image.open("./roses.jpg")  # check the rose
plt.imshow(img)
# [C, H, W] after the transform
img = data_transform(img)
# expand batch dimension -> [N, C, H, W]
img = torch.unsqueeze(img, dim=0)

# read class_indict (index -> class name mapping written by train.py)
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except (OSError, ValueError) as e:
    # Missing/unreadable file or invalid JSON: report it and stop.
    print(e)
    raise SystemExit(-1)

# create model
model = AlexNet(num_classes=5)
# load model weights; map to CPU so weights saved from a GPU run also load
# on a machine without CUDA (the model itself stays on the CPU here).
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path, map_location="cpu"))
model.eval()
with torch.no_grad():
    # predict class: drop the batch dimension, then turn scores into
    # probabilities and pick the most likely index.
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).item()
# JSON object keys are strings, hence the str() lookup.
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()