ResNet
提出了一种残差学习框架来解决网络退化问题,从而训练更深的网络。这种框架可以结合已有的各种网络结构,充分发挥二者的优势。
ResNet
以三种方式挑战了传统的神经网络架构:
ResNet
通过引入跳跃连接来绕过残差层,这允许数据直接流向任何后续层。
这与传统的、顺序的pipeline
形成鲜明对比:传统的架构中,网络依次处理低级feature
到高级feature
。
ResNet
的层数非常深,高达1202层。而ALexNet
这样的架构,网络层数要小两个量级。
通过实验发现,训练好的 ResNet
中去掉单个层并不会影响其预测性能。而训练好的AlexNet
等网络中,移除层会导致预测性能损失。
在ImageNet
分类数据集中,拥有152层的残差网络,以3.75% top-5
的错误率获得了ILSVRC 2015
分类比赛的冠军。
很多证据表明:残差学习是通用的,不仅可以应用于视觉问题,也可应用于非视觉问题。
论文地址: https://arxiv.org/pdf/1512.03385.pdf
卷积神经网络领域的两次技术爆炸,第一次是AlexNet,第二次就是ResNet了。
1、理论上来讲网络深度越深越好。网络越深,提取的图片特征越多越丰富,但随之会带来很多的问题(通过Batch Normalization
在很大程度上解决),比如过拟合或者计算量爆炸、梯度消失、梯度爆炸等,导致网络在一定深度下就达到了局部最优解。
2、ResNet
论文作者发现:随着网络的深度的增加,准确率达到饱和之后迅速下降,而这种下降不是由过拟合引起的。这称作网络退化问题
。如果更深的网络训练误差更大,则说明是由于优化算法引起的:越深的网络,求解优化问题越难。如下所示:更深的网络导致更高的训练误差和测试误差。
通过多个非线性层来近似横等映射可能是困难的
。每个附加层都应该更容易地包含原始函数作为其元素之⼀
。1、假设需要学习的是映射 y = H(x),残差块使用堆叠的非线性层拟合残差:y = F(x,W) + x 。
其中:
+
:通过快捷连接
逐个元素相加来执行。快捷连接
指的是那些跳过一层或者更多层的连接。
2、残差映射易于捕捉恒等映射的细微波动
。比如5正常映射为5.1,加入残差后变成 5+0.1。此时输入变成5.2,对于没有残差结构的结果,影响仅为0.1/5.1 = 2%。而对于残差结构,变成 5+0.2 , 由0.1变成了0.2 影响为100%。
3、残差映射 H ( x ) = F ( x ) + x ,在反向传播的时候就变成了 H ′ ( x ) = F ′ ( x ) + 1,这里的加1也可以保证梯度消失现象
4、作者也证明了退化问题在任何数据集上都普遍存在。在imagenet上拿到冠军之后,迁移学习用到了coco同样拿到了好几个赛道的冠军,说明残差结构是普适的。最后又和VGG比了一下,比VGG深了8倍,计算复杂性却还比VGG小 。
层数可变:论文中的实验包含有两层堆叠、三层堆叠,实际任务中也可以包含更多层的堆叠。
如果F
只有一层,则残差块退化线性层:y = Wx + x
。此时对网络并没有什么提升。
连接形式可变:不仅可用于全连接层,可也用于卷积层。此时F代表多个卷积层的堆叠,而最终的逐元素加法+
在两个feature map
上逐通道进行。
此时
x
也是一个feature map
,而不再是一个向量。
学习残差F(x,W)比学习原始映射H(x)要更容易。
1、当原始映射H
就是一个恒等映射时, 就是一个F
零映射。此时求解器只需要简单的将堆叠的非线性连接的权重推向零即可。
实际任务中原始映射 H
可能不是一个恒等映射:
H
更偏向于恒等映射(而不是更偏向于非恒等映射),则F
就是关于恒等映射的抖动,会更容易学习。H
更偏向于零映射,那么学习 本身要更容易。但是在实际应用中,零映射非常少见,因为它会导致输出全为0。2、如果原始映射H
是一个非恒等映射,则可以考虑对残差模块使用缩放因子。如Inception-Resnet
中:在残差模块与快捷连接叠加之前,对残差进行缩放。注意:ResNet
作者在随后的论文中指出:不应该对恒等映射进行缩放。
3、可以通过观察残差 F
的输出来判断:如果F
的输出均为0附近的、较小的数,则说明原始映射H
更偏向于恒等映射;否则,说明原始映射H
更偏向于非横等映射。
from torch import nn
from torch.nn import functional as F
import torch
'''
⼀种是当use_1x1conv=False时,应⽤ReLU⾮线性函数之前,将输⼊添加到输出。
另⼀种是当use_1x1conv=True时,添加通过1 × 1卷积调整通道和分辨率
ResNet沿⽤了VGG完整的3 × 3卷积层设计。
残差块⾥⾸先有2个有相同输出通道数的3 × 3卷积层。
每个卷积层后接⼀个批量规范化层和ReLU激活函数。
然后我们通过跨层数据通路,跳过这2个卷积运算,将输⼊直接加在最后的ReLU激活函数前。
这样的设计要求2个卷积层的输出与输⼊形状⼀样,从⽽使它们可以相加。
如果想改变通道数,就需要引⼊⼀个额外的1 × 1卷积层来将输⼊变换成需要的形状后再做相加运算。
'''
class Residual(nn.Module):
def __init__(self,input_channels, num_channels,use_1x1conv=False, strides=1):
super().__init__()
self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)
self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)
if use_1x1conv:
self.conv3 = nn.Conv2d(input_channels, num_channels,
kernel_size=1, stride=strides)
else:
self.conv3 = None
self.bn1 = nn.BatchNorm2d(num_channels)
self.bn2 = nn.BatchNorm2d(num_channels)
def forward(self, X):
Y = F.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3:
X = self.conv3(X)
Y += X
return F.relu(Y)
if __name__ == '__main__':
blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
print(Y.shape) # 输⼊和输出形状⼀致 torch.Size([4, 3, 6, 6])
blk = Residual(3, 6, use_1x1conv=True, strides=2)
Y = blk(X)
print(Y.shape) # 在增加输出通道数的同时,减半输出的高和宽 torch.Size([4, 6, 3, 3])
plain
网络plain
网络:一些简单网络结构的叠加,如下图所示。图中给出了四种plain
网络,它们的区别主要是网络深度不同。其中,输入图片尺寸 224x224 。
ResNet
简单的在plain
网络上添加快捷连接来实现。
FLOPs
:floating point operations
的缩写,意思是浮点运算量,用于衡量算法/模型的复杂度。
FLOPS
:floating point per second
的缩写,意思是每秒浮点运算次数,用于衡量计算速度。
相对于输入的feature map
,残差块的输出feature map
尺寸可能会发生变化:
输出 feature map
的通道数增加,此时需要扩充快捷连接的输出feature map
。否则快捷连接的输出 feature map
无法和残差块的feature map
累加。
有两种扩充方式:
1x1
卷积来扩充维度。输出 feature map
的尺寸减半。此时需要对快捷连接执行步长为 2 的池化/卷积:如果快捷连接已经采用 1x1
卷积,则该卷积步长为2 ;否则采用步长为 2 的最大池化 。
VGG-19 | 34层 plain 网络 | Resnet-34 | |
---|---|---|---|
计算复杂度(FLOPs) | 19.6 billion | 3.5 billion | 3.6 billion |
在ImageNet
验证集上执行10-crop
测试的结果。
A
类模型:快捷连接中,所有需要扩充的维度的填充 0 。B
类模型:快捷连接中,所有需要扩充的维度通过1x1
卷积来扩充。C
类模型:所有快捷连接都通过1x1
卷积来执行线性变换。C
优于B
,B
优于A
。但是 C
引入更多的参数,相对于这种微弱的提升,性价比较低。所以后续的ResNet
均采用 B
类模型。
模型 | top-1 误差率 | top-5 误差率 |
---|---|---|
VGG-16 | 28.07% | 9.33% |
GoogleNet | - | 9.15% |
PReLU-net | 24.27% | 7.38% |
plain-34 | 28.54% | 10.02% |
ResNet-34 A | 25.03% | 7.76% |
ResNet-34 B | 24.52% | 7.46% |
ResNet-34 C | 24.19% | 7.40% |
ResNet-50 | 22.85% | 6.71% |
ResNet-101 | 21.75% | 6.05% |
ResNet-152 | 21.43% | 5.71% |
import torch.nn as nn
import torch
from _06_Residual import Residual
class ResNet18(nn.Module):
def __init__(self):
super(ResNet18, self).__init__()
self.model = self.get_net()
def forward(self, X):
X = self.model(X)
return X
def get_net(self):
'''
ResNet的前两层跟GoogLeNet中的⼀样:
在输出通道数为64、步幅为2的7 × 7卷积层后,接步幅为2的3 × 3的最⼤汇聚层。
不同之处在于ResNet每个卷积层后增加了批量规范化层。
'''
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
'''
GoogLeNet在后⾯接了4个由Inception块组成的模块。
ResNet则使⽤4个由残差块组成的模块,每个模块使⽤若⼲个同样输出通道数的残差块。
第⼀个模块的通道数同输⼊通道数⼀致。由于之前已经使⽤了步幅为2的最⼤汇聚层,所以⽆须减⼩⾼和宽。
之后的每个模块在第⼀个残差块⾥将上⼀个模块的通道数翻倍,并将⾼和宽减半。
'''
b2 = nn.Sequential(*self.resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*self.resnet_block(64, 128, 2))
b4 = nn.Sequential(*self.resnet_block(128, 256, 2))
b5 = nn.Sequential(*self.resnet_block(256, 512, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(), nn.Linear(512, 10))
return net
def resnet_block(self, input_channels, num_channels, num_residuals, first_block=False):
blk = []
for i in range(num_residuals):
if i == 0 and not first_block:
blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
else:
blk.append(Residual(num_channels, num_channels))
return blk
if __name__ == '__main__':
net = ResNet18()
X = torch.rand(size=(1, 1, 224, 224), dtype=torch.float32)
for layer in net.model:
X = layer(X)
print(layer.__class__.__name__, 'output shape:', X.shape)
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 128, 28, 28])
Sequential output shape: torch.Size([1, 256, 14, 14])
Sequential output shape: torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape: torch.Size([1, 512, 1, 1])
Flatten output shape: torch.Size([1, 512])
Linear output shape: torch.Size([1, 10])
如1.2.4及1.3.3代码所示。
其他所有的函数,与经典神经网络(1)LeNet及其在Fashion-MNIST数据集上的应用完全一致。
batch_size = 256
# 为了使Fashion-MNIST上的训练短⼩精悍,将输⼊的⾼和宽从224降到96,简化计算
train_iter,test_iter = get_mnist_data(batch_size,resize=96)
from _06_ResNet18 import ResNet18
# 初始化模型
net = ResNet18()
lr, num_epochs = 0.05, 10
train_ch(net, train_iter, test_iter, num_epochs, lr, try_gpu())
import requests
import urllib3
urllib3.disable_warnings()
import time
import os
import random
import pandas as pd
import shutil
import matplotlib.pyplot as plt
%matplotlib inline
import cv2
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")
# 进度条库
from tqdm import tqdm
# http请求参数
cookies = {
'BDqhfp': '%E7%8B%97%E7%8B%97%26%26NaN-1undefined%26%2618880%26%2621',
'BIDUPSID': '06338E0BE23C6ADB52165ACEB972355B',
'PSTM': '1646905430',
'BAIDUID': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
'BDORZ': 'B490B5EBF6F3CD402E515D22BCDA1598',
'H_PS_PSSID': '35836_35105_31254_36024_36005_34584_36142_36120_36032_35993_35984_35319_26350_35723_22160_36061',
'BDSFRCVID': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
'H_BDCLCKID_SF': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
'BDSFRCVID_BFESS': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
'H_BDCLCKID_SF_BFESS': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
'indexPageSugList': '%5B%22%E7%8B%97%E7%8B%97%22%5D',
'cleanHistoryStatus': '0',
'BAIDUID_BFESS': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
'BDRCVFR[dG2JNJb_ajR]': 'mk3SLVN4HKm',
'BDRCVFR[-pGxjrCMryR]': 'mk3SLVN4HKm',
'ab_sr': '1.0.1_Y2YxZDkwMWZkMmY2MzA4MGU0OTNhMzVlNTcwMmM2MWE4YWU4OTc1ZjZmZDM2N2RjYmVkMzFiY2NjNWM4Nzk4NzBlZTliYWU0ZTAyODkzNDA3YzNiMTVjMTllMzQ0MGJlZjAwYzk5MDdjNWM0MzJmMDdhOWNhYTZhMjIwODc5MDMxN2QyMmE1YTFmN2QyY2M1M2VmZDkzMjMyOThiYmNhZA==',
'delPer': '0',
'PSINO': '2',
'BA_HECTOR': '8h24a024042g05alup1h3g0aq0q'
}
headers = {
'Connection': 'keep-alive',
'sec-ch-ua': '" Not;A Brand";v="99", "Google Chrome";v="97", "Chromium";v="97"',
'Accept': 'text/plain, */*; q=0.01',
'X-Requested-With': 'XMLHttpRequest',
'sec-ch-ua-mobile': '?0',
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.99 Safari/537.36',
'sec-ch-ua-platform': '"macOS"',
'Sec-Fetch-Site': 'same-origin',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Dest': 'empty',
'Referer': 'https://image.baidu.com/search/index?tn=baiduimage&ipn=r&ct=201326592&cl=2&lm=-1&st=-1&fm=result&fr=&sf=1&fmq=1647837998851_R&pv=&ic=&nc=1&z=&hd=&latest=©right=&se=1&showtab=0&fb=0&width=&height=&face=0&istype=2&dyTabStr=MCwzLDIsNiwxLDUsNCw4LDcsOQ%3D%3D&ie=utf-8&sid=&word=%E7%8B%97%E7%8B%97',
'Accept-Language': 'zh-CN,zh;q=0.9'
}
def download_single_class(file_path, keyword, DOWNLOAD_NUM=100):
if not os.path.exists(file_path + "/dataset"):
os.makedirs(file_path + "/dataset")
print(f'新建{file_path}/dataset文件夹')
if not os.path.exists(file_path + "/dataset/" + keyword):
os.makedirs(file_path + "/dataset/"+ keyword)
print('新建文件夹:{}/dataset/{}'.format(file_path, keyword))
else:
print('文件夹:{}/dataset/{}已经存在,之后将爬取的图片保存到该文件夹中'.format(file_path, keyword))
count = 1
with tqdm(total=DOWNLOAD_NUM, position=0, leave=True) as pbar:
# 爬取第几张
num = 1
# 是否继续爬取
FLAG = True
while FLAG:
page = 30 * count
params = (
('tn', 'resultjson_com'),
('logid', '12508239107856075440'),
('ipn', 'rj'),
('ct', '201326592'),
('is', ''),
('fp', 'result'),
('fr', ''),
('word', f'{keyword}'),
('queryWord', f'{keyword}'),
('cl', '2'),
('lm', '-1'),
('ie', 'utf-8'),
('oe', 'utf-8'),
('adpicid', ''),
('st', '-1'),
('z', ''),
('ic', ''),
('hd', ''),
('latest', ''),
('copyright', ''),
('s', ''),
('se', ''),
('tab', ''),
('width', ''),
('height', ''),
('face', '0'),
('istype', '2'),
('qc', ''),
('nc', '1'),
('expermode', ''),
('nojc', ''),
('isAsync', ''),
('pn', f'{page}'),
('rn', '30'),
('gsm', '1e'),
('1647838001666', ''),
)
response = requests.get('https://image.baidu.com/search/acjson', headers=headers, params=params,
cookies=cookies)
if response.status_code == 200:
try:
json_data = response.json().get("data")
if json_data:
for x in json_data:
type = x.get("type")
if type not in ["gif"]:
img = x.get("thumbURL")
fromPageTitleEnc = x.get("fromPageTitleEnc")
try:
resp = requests.get(url=img, verify=False)
time.sleep(1)
# print(f"链接 {img}")
# 保存文件名
# file_save_path = f'dataset/{keyword}/{num}-{fromPageTitleEnc}.{type}'
file_save_path = file_path + f'/dataset/{keyword}/{num}.{type}'
with open(file_save_path, 'wb') as f:
f.write(resp.content)
f.flush()
# print('第 {} 张图像 {} 爬取完成'.format(num, fromPageTitleEnc))
num += 1
pbar.update(1) # 进度条更新
# 爬取数量达到要求
if num > DOWNLOAD_NUM:
FLAG = False
print('{} 张图像爬取完毕'.format(num - 1))
break
except Exception:
pass
except:
pass
else:
break
count += 1
# 测试爬取香蕉
file_path = 'D:\python\kaggle\pictures_classfication_data'
download_single_class(file_path,'香蕉', DOWNLOAD_NUM=2)
# 爬取多类水果
class_list = ['苹果','梨','葡萄','火龙果','大枣',
'柑橘','柚子','桃','杏','西瓜',
'荔枝','甘蔗','柿子','羊角蜜','香蕉',
'菠萝','芒果','哈密瓜','石榴','椰子'
]
for class_name in class_list:
download_single_class(file_path,class_name)
新建文件夹:D:\python\kaggle\pictures_classfication_data/dataset/苹果
100%|██████████| 100/100 [03:17<00:00, 1.98s/it]
100 张图像爬取完毕
新建文件夹:D:\python\kaggle\pictures_classfication_data/dataset/椰子
100%|██████████| 100/100 [02:57<00:00, 1.77s/it]
100 张图像爬取完毕
file_path = file_path + '/dataset'
classes = os.listdir(file_path)
# 创建 train 文件夹
os.mkdir(os.path.join(file_path, 'train'))
# 创建 test 文件夹
os.mkdir(os.path.join(file_path, 'val'))
# 在 train 和 test 文件夹中创建各类别子文件夹
for fruit in classes:
os.mkdir(os.path.join(file_path, 'train', fruit))
os.mkdir(os.path.join(file_path, 'val', fruit))
test_frac = 0.2 # 测试集比例
random.seed(123) # 随机数种子,便于复现
df = pd.DataFrame()
print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))
for fruit in classes: # 遍历每个类别
# 读取该类别的所有图像文件名
old_dir = os.path.join(file_path, fruit)
images_filename = os.listdir(old_dir)
random.shuffle(images_filename) # 随机打乱
# 划分训练集和测试集
testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数
testset_images = images_filename[:testset_numer] # 获取拟移动至 test 目录的测试集图像文件名
trainset_images = images_filename[testset_numer:] # 获取拟移动至 train 目录的训练集图像文件名
# 移动图像至 test 目录
for image in testset_images:
old_img_path = os.path.join(file_path, fruit, image) # 获取原始文件路径
new_test_path = os.path.join(file_path, 'val', fruit, image) # 获取 test 目录的新文件路径
shutil.move(old_img_path, new_test_path) # 移动文件
# 移动图像至 train 目录
for image in trainset_images:
old_img_path = os.path.join(file_path, fruit, image) # 获取原始文件路径
new_train_path = os.path.join(file_path, 'train', fruit, image) # 获取 train 目录的新文件路径
shutil.move(old_img_path, new_train_path) # 移动文件
# 删除旧文件夹
assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走
shutil.rmtree(old_dir) # 删除文件夹
# 输出每一类别的数据个数
print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))
# 保存到表格中
df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)
类别 训练集数据个数 测试集数据个数
# 数据集各类别数量统计表格,导出为 csv 文件
df['total'] = df['trainset'] + df['testset']
df.head()
df.to_csv('数据量统计.csv', index=False)
class | testset | trainset | total | |
---|---|---|---|---|
0 | 哈密瓜 | 20.0 | 80.0 | 100.0 |
1 | 大枣 | 20.0 | 80.0 | 100.0 |
2 | 杏 | 20.0 | 80.0 | 100.0 |
3 | 柑橘 | 20.0 | 80.0 | 100.0 |
4 | 柚子 | 20.0 | 80.0 | 100.0 |
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
figsize = (num_cols * scale, num_rows * scale)
_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
#读取图像,解决imread不能读取中文路径路径的问题
def cv_imread(file_path):
cv_img = cv2.imdecode(np.fromfile(file_path,dtype=np.uint8),-1)
return cv_img
# 读取训练集【西瓜】文件夹所有的图像
folder_path = os.path.join(file_path ,'train' , '西瓜')
images = []
for each_img in os.listdir(folder_path):
img_path = os.path.join(folder_path, each_img)
img_bgr = cv_imread(img_path)
img_rgb = cv2.cvtColor(img_bgr,cv2.COLOR_BGR2RGB)
images.append(img_rgb)
show_images([images[i] for i in range(32)],num_rows=4, num_cols=8, scale=1.0)
'''
将dataset水果分类打成zip压缩包,上传的linux机器上,用GPU训练
Linux下的默认编码是UTF8,Windows下生成的zip文件中的编码是GBK/GB2312等.zip文件
在Linux下解压时出现乱码问题.执行一下命令:
unzip -O GB18030 dataset.zip
'''
file_path = '/root/autodl-fs/data/fruit20/dataset'
def try_gpu(i=0):
if torch.cuda.device_count() >= i + 1:
return torch.device(f'cuda:{i}')
return torch.device('cpu')
'''
1、图像预处理
'''
from torchvision import transforms
# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
'''
2、载入水果图像分类数据集
'''
train_path = os.path.join(file_path, 'train')
test_path = os.path.join(file_path, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)
from torchvision import datasets
# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)
训练集路径 /root/autodl-fs/data/fruit20/dataset/train
测试集路径 /root/autodl-fs/data/fruit20/dataset/val
训练集图像数量 1600
类别个数 20
各类别名称 ['哈密瓜', '大枣', '杏', '柑橘', '柚子', '柿子', '桃', '梨', '椰子', '火龙果', '甘蔗', '石榴', '羊角蜜', '芒果', '苹果', '荔枝', '菠萝', '葡萄', '西瓜', '香蕉']
测试集图像数量 400
类别个数 20
各类别名称 ['哈密瓜', '大枣', '杏', '柑橘', '柚子', '柿子', '桃', '梨', '椰子', '火龙果', '甘蔗', '石榴', '羊角蜜', '芒果', '苹果', '荔枝', '菠萝', '葡萄', '西瓜', '香蕉']
'''
3、类别索引 映射字典
'''
# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)
# 映射关系:类别 到 索引号
train_dataset.class_to_idx
# 映射关系:索引号 到 类别
idx_to_labels = {y:x for x,y in train_dataset.class_to_idx.items()}
idx_to_labels
{0: '哈密瓜',
1: '大枣',
2: '杏',
3: '柑橘',
4: '柚子',
5: '柿子',
6: '桃',
7: '梨',
8: '椰子',
9: '火龙果',
10: '甘蔗',
11: '石榴',
12: '羊角蜜',
13: '芒果',
14: '苹果',
15: '荔枝',
16: '菠萝',
17: '葡萄',
18: '西瓜',
19: '香蕉'}
# 保存为本地的 npy 文件
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)
'''
4、加载数据集
'''
from torch.utils.data import DataLoader
BATCH_SIZE = 32
# 训练集的数据加载器
train_iter = DataLoader(train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0
)
# 测试集的数据加载器
test_iter = DataLoader(test_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=0
)
'''
5、微调最后一层,创建resnet-18模型
'''
net = torchvision.models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True
net.fc = nn.Linear(net.fc.in_features, n_class)
'''
6、模型训练
'''
import torch.nn as nn
from AccumulatorClass import Accumulator
def accuracy(y_hat, y):
"""计算预测正确的数量"""
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
y_hat = y_hat.argmax(axis=1)
cmp = y_hat.type(y.dtype) == y
return float(cmp.type(y.dtype).sum())
def evaluate_accuracy_gpu(net, data_iter, device=None):
"""使⽤GPU计算模型在数据集上的精度"""
if isinstance(net, nn.Module):
net.eval() # 设置为评估模式
if not device:
device = next(iter(net.parameters())).device
# 正确预测的数量,总预测的数量
metric = Accumulator(2)
with torch.no_grad():
for X, y in data_iter:
if isinstance(X, list):
# BERT微调所需的
X = [x.to(device) for x in X]
else:
X = X.to(device)
y = y.to(device)
metric.add(accuracy(net(X), y), y.numel())
return metric[0] / metric[1]
from AnimatorClass import Animator
from TimerClass import Timer
def train_ch(net, train_iter, test_iter, num_epochs, lr, device):
"""⽤GPU训练模型"""
print('training on', device)
net.to(device)
optimizer = torch.optim.SGD(net.fc.parameters(), lr=lr)
# 只微调训练最后一层全连接层的参数,其它层冻结
# optimizer = torch.optim.Adam(net.fc.parameters())
# 学习率降低策略
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# 交叉熵损失
loss = nn.CrossEntropyLoss()
animator = Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])
timer, num_batches = Timer(), len(train_iter)
num_batches = len(train_iter)
best_test_accuracy = 0.0
for epoch in range(num_epochs):
# 训练损失之和,训练准确率之和,样本数
metric = Accumulator(3)
net.train()
for i, (X, y) in enumerate(train_iter):
timer.start()
optimizer.zero_grad()
X, y = X.to(device), y.to(device)
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
optimizer.step()
# lr_scheduler.step()
with torch.no_grad():
metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])
timer.stop()
train_l = metric[0] / metric[2]
train_acc = metric[1] / metric[2]
if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))
test_acc = evaluate_accuracy_gpu(net, test_iter)
if test_acc > best_test_accuracy:
# 删除旧的最佳模型文件(如有)
old_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(best_test_accuracy)
if os.path.exists(old_best_checkpoint_path):
os.remove(old_best_checkpoint_path)
# 保存新的最佳模型文件
new_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(test_acc)
torch.save(net, new_best_checkpoint_path)
print('保存新的最佳模型', 'checkpoint/best-{:.3f}.pth'.format(test_acc))
best_test_accuracy = test_acc
animator.add(epoch + 1, (None, None, test_acc))
print(f'best_test_accuracy = {best_test_accuracy:.3f}')
print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, test acc {test_acc:.3f}')
print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')
# 初始化模型
lr, num_epochs = 0.1, 10
train_ch(net, train_iter, test_iter, num_epochs, lr, try_gpu())
best_test_accuracy = 0.840
loss 0.546, train acc 0.832, test acc 0.810
565.2 examples/sec on cuda:0
best_test_accuracy = 0.840
# 载入最佳模型作为当前模型
net = torch.load('checkpoint/best-{:.3f}.pth'.format(best_test_accuracy))
net.to(try_gpu())
test_iter = DataLoader(test_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0
)
def get_fruit_labels(labels):
"""返回fruit20数据集的⽂本标签"""
text_labels = test_dataset.classes
return [text_labels[int(i)] for i in labels]
def predict(net,test_iter, n=10):
for X,y in test_iter:
trues = get_fruit_labels(y[0:n])
outputs = net(X.to(try_gpu())) # 输入模型,执行前向预测
_, preds = torch.max(outputs, 1)
preds = get_fruit_labels(
preds.cpu().numpy()[0:n]
)
print('trues:',trues)
print('preds:',preds)
break
predict(net,test_iter)
trues: ['菠萝', '梨', '柿子', '大枣', '芒果', '菠萝', '芒果', '苹果', '桃', '菠萝']
preds: ['菠萝', '梨', '柿子', '大枣', '芒果', '菠萝', '芒果', '苹果', '桃', '菠萝']