记一次目标分类任务流程拉通——pytorch+resnet50+重新筛选的部分challenge2018农作物病害数据集(玉米)

做这个事情是为了拉通流程,所以对数据集进行了重新制作,只选取了部分类别来完成目标分类任务。
代码地址:https://github.com/AlannahYYL/plant_classification
完整数据集:https://pan.baidu.com/wap/init?surl=6f1nQchS-zBtzSWn9Guyyg 提取码:iksk
筛选后的玉米:https://pan.baidu.com/s/1GLo51C_y2pcoESDFUrgyDA 提取码:sdfc

一.项目环境

python==3.6   torch==1.6   torchvision==0.7.0

二.数据集介绍

首先说下数据集:AI Challenger 2018农作物病害检测竞赛就是由上海新客科技为竞赛提供农作物叶子图像的数据集:标注图片5万张,包含10种植物(苹果、樱桃、葡萄、柑桔、桃、草莓、番茄、辣椒、玉米、马铃薯)的27种病害,合计61个分类(按“物种-病害-程度”分)
标签类别对照表:见https://github.com/xungeer29/AI-Challenger-Plant-Disease-Recognition/blob/master/README.md
标注文件为json文件,是一个列表,列表中每张图片的信息以字典形式保存。key值:“disease_class”, “image_id”。

三.数据集制作

首先参考这篇进行的数据分布的分析,选择了9-16玉米的病害数据来完成本次任务。

with open(train_data_json) as datafile1:
    trainDataFram=pd.read_json(datafile1,orient='records')
with open(val_data_json) as datafile2: #first check if it's a valid json file or not
    validateDataFram =pd.read_json(datafile2,orient='records')
total=trainDataFram.isnull().sum().sort_values(ascending=False)
percent=(trainDataFram.isnull().sum())/(trainDataFram.isnull().count()).sort_values(ascending = False)
missing_validation_data = pd.concat([total, percent], axis=1, keys=['Total', 'Percent'],sort=False)
# print(missing_validation_data.head())
dataDistribute=trainDataFram.groupby(by=['disease_class']).size()
# print(dataDistribute)
plt.figure(figsize=(50,20),dpi=100)
plt.xticks(range(len(dataDistribute)),dataDistribute.index.tolist(),fontsize=40)
plt.yticks(fontsize=40)
bar=plt.bar(dataDistribute.index.tolist(), dataDistribute.tolist(),width=0.7)
for b in bar:
    h=b.get_height()
    plt.text(b.get_x()+b.get_width()/2,h,int(h),ha='center',fontsize=30)
plt.show()

记一次目标分类任务流程拉通——pytorch+resnet50+重新筛选的部分challenge2018农作物病害数据集(玉米)_第1张图片
数据集制作使用的pandas进行数据筛选,然后生成train和val文件夹,里面包含标签9-16命名的图片文件夹,然后将对应类别的图片导入文件夹,代码简单详见git。

四.训练

训练脚本来自:https://github.com/pytorch/examples/tree/master/imagenet
根据自己的情况进行修改:
1.开头argparse添加数据集路径

parser.add_argument('-data', default='/opt/yyl/data/plant2018/', metavar='DIR',help='path to dataset')

2.简便起见,resnet50写入default

parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',choices=model_names, help='model architecture: ' +' | '.join(model_names) +' (default: resnet18)')

3.使用预训练权重

parser.add_argument('--pretrained', default='True',dest='pretrained', action='store_true', help='use pre-trained model')

4.创建模型后面修改全连接层输出

#fc modified 
model.fc = torch.nn.Linear(in_features=2048,out_features=8,bias=True)
print(model)

5.同样在开头设置batchsize、epoch、指定gpu等就不赘述
好了,在根目录下创建一个models文件夹,然后直接运行train.py就可以开始训练了。

本次训练的最优结果:

在这里插入图片描述

青古の每篇一歌
《外婆》
外婆她的无奈
无法变成期待
只有爱才能够明白

你可能感兴趣的:(ML/DL/数据结构与算法,python,深度学习,python)