关于用pytorch构建vgg网络实现花卉分类的学习笔记

需要的第三方库:

pytorch、matplotlib、json、os、tqdm

一、model.py的编写

(1)准备工作

1.参照vgg网络结构图(如下图1),定义一个字典,用于存放各种vgg网络,字典如下图2(M表示最大池化层)
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第1张图片
在这里插入图片描述
2.定义一个获取特征的函数,此处命名为make_features,参数为模型名字,再遍历字典中键对应的值列表,向layers中加入对应的卷积层和池化层,最后返回打包完成的feature(非关键字参数),用于后续操作
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第2张图片

(2)VGG类的定义

创建一个VGG类,父类为nn.Module,初始化函数的参数中设置:feature(包含网络中卷积与池化各层,并用nn.Sequential打包完成)、num_classes(对应类别个数,即全连接层最后的节点个数,设为1000);
再在初始化函数中编写三个全连接层,并打包,如下图:
全连接层之间先使用relu函数激活,再使用dropout,使一半的神经元随机失活,防止过拟合
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第3张图片
然后定义其正向传播过程,将features从维度为1展平,再放入classifier,然后返回值
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第4张图片
初始化权重
利用for循环,对卷积和池化层分别进行权重初始化,并对偏置量bias置0
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第5张图片

vgg函数的编写

定义函数名为vgg的函数,参数为model_name和可变长度参数,用于直接实例化VGG类对象,返回对应的model
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第6张图片

二、train.py的编写

定义main函数,自动调用,完成训练与验证

验证是否使用Gpu:
在这里插入图片描述

(1)数据预处理

定义一个字典,存放对训练集和验证集的处理;
训练的图像经过随机裁剪,随机水平反转,数据转换为tensor格式,以及对数据进行标准化处理;验证集则需要强行变为224*224的图像,再转换为tensor格式,最后标准化
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第7张图片


(2)相关数据集位置读取

在这里插入图片描述
以上一行代码可指定当前py文件上两级目录的绝对位置(以下图为例,train.py存放在VGG_ pytorch,则data_root即为projects文件夹的绝对位置)
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第8张图片

根据数据集的存放位置,利用os.path.join拼接得到图片的路径,再分别对应训练和验证集进行打开数据集位置并进行预处理,再加载数据集,以下以训练集为例,测试集代码与之类似,不做展示,需要注意的是,num_workers在Windows平台一般不能置为非0数字,若是Linux等平台可进行修改
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第9张图片

利用.class_to_idx方法生成以类别为键,数字为值的字典,再将其键值交换,写入json文件,效果如下:
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第10张图片


(3)模型实例化

利用model.py中的vgg函数实例化一个net对象,使之成为训练所使用的model,损失函数选择交叉熵,优化器选择Adam,并将learning rate(学习率)设为0.0001(下图以vgg16为例,init_weights是model.py中定义的初始化权重),再定义存放权重文件的路径
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第11张图片

利用tqdm对加载的训练集进行处理,再进行遍历,对其进行梯度置0,输出置于GPU,计算预测值与真实值的loss,再将loss反向传播置每一节点,最后根据loss更新参数(注意要先利用net.train()启用dropout);如下图:
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第12张图片

对于验证集的处理有部分不同:
1.先使用.eval()禁用dropout,同时禁止计算损失梯度关于用pytorch构建vgg网络实现花卉分类的学习笔记_第13张图片
2.循环中预测值输出为每行最大值,即可能性最高的预测值;
对预测值和真实值进行判断,若相等,则acc+1,不相等则不加
在这里插入图片描述

再用acc/验证集总个数得到验证集准确率,与之前迭代产生的验证集准确率作比较,在结尾处书写如下代码,则使得最终保存acc最高的权重数据
在这里插入图片描述

三、 predict.py的编写

首先对需要的数据进行如train.py中验证集一样的数据预处理,然后直接打开展示待检测图像,再为其添加batch维度(如下图)
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第14张图片
然后读取之前写入的json文件,初始化网络,载入权重文件,完成网络模型的载入;
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第15张图片
再禁用dropout、禁止计算损失梯度,将图片放入模型再压缩掉batch维度,最终得到输出,再通过softmax得到其概率分布,最后通过 torch.argmax()得到概率最大处的索引值,打印出结果
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第16张图片

四、过程中的一些问题

1.运行train.py时,出现报错

OSError: [WinError 1455] 页面文件太小,无法完成操作。

解决方案:
https://blog.csdn.net/qq_17755303/article/details/112564030

2.运行train.py时,出现报错

RuntimeError: CUDA out of memory. Tried to allocate 50.00 MiB (GPU 0; 6.00 G

其原因为batch_size设置过大,初始值设置为32,出现上述报错,改为16之后,便正常运行

附上效果图:
在这里插入图片描述
关于用pytorch构建vgg网络实现花卉分类的学习笔记_第17张图片

你可能感兴趣的:(关于用pytorch构建vgg网络实现花卉分类的学习笔记)