工业缺陷检测项目实战(四)——基于HRNet的陶瓷缺陷检测

基于HRNet的陶瓷缺陷检测

1.原理:
参考大佬们的文章
HRNet: HRNet原理.

2.数据集准备和代码
数据下载链接:https://aistudio.baidu.com/aistudio/datasetdetail/32615
代码下载链接:https://gitee.com/wxyfmq123456/HRNet-Image-Classification?_from=gitee_search
3.原图二值化
这里数据集已经提供了二值化图像的png,我们需要用png图像进行训练。因为原图像特别不明显。
工业缺陷检测项目实战(四)——基于HRNet的陶瓷缺陷检测_第1张图片
总共6个类别。
4. 参数配置
(1) 数据存放位置:
工业缺陷检测项目实战(四)——基于HRNet的陶瓷缺陷检测_第2张图片
(2) 数据存放方式:
工业缺陷检测项目实战(四)——基于HRNet的陶瓷缺陷检测_第3张图片
每个文件夹代表一个类型,里面的图片全部都是二值化图片(.png),原图(.jpg)可以删去或者备份在其他地方。

(3) 修改代码
工业缺陷检测项目实战(四)——基于HRNet的陶瓷缺陷检测_第4张图片
打开cls_hrnet.py,修改

self.classifier = nn.Linear(2048, 1000)

self.classifier = nn.Linear(2048, 6)

也就是后面的参数为类别数

(4) 选择配置文件
工业缺陷检测项目实战(四)——基于HRNet的陶瓷缺陷检测_第5张图片
我们选择第一个,即:
cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml

修改里面的参数,将

DATASET:
  DATASET: 'imagenet'
  DATA_FORMAT: 'jpg'
  ROOT: 'data/imagenet/'
  TEST_SET: 'val'
  TRAIN_SET: 'train'

修改为

DATASET:
  DATASET: 'data'
  DATA_FORMAT: 'png'
  ROOT: 'imagenet'
  TEST_SET: 'val'
  TRAIN_SET: 'train'

原因是我们设置的路径跟代码的不一样。
其他参数,比如迭代次数epoch,bath_size等,可以自己调参。

(5) 在train.py中增加如下代码

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    )
 
 
    #以下为增加的代码,上面几行是原有的代码
    #print(train_dataset.classes)  #根据分的文件夹的名字来确定的类别
    with open("class.txt","w") as f1:
        for classname in train_dataset.classes:
            f1.write(classname + "\n")
 
    #print(train_dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
    with open("classToIndex.txt", "w") as f2:
        for key, value in train_dataset.class_to_idx.items():
            f2.write(str(key) + " " + str(value) + '\n')
 
    #print(train_dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别

可以保持类别对应的index。

5.训练

python  tools/train.py --cfg  experiments/cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml

训练完在output文件夹里面有
权重偏置文件:final_state.pth.tar

6.测试
测试的图片,注意这里还是读取验证集vaild,所以为了测试一张图片,我们可以把验证集的图片变为一张,放在例如名为test的文件夹里面,路径如图:
工业缺陷检测项目实战(四)——基于HRNet的陶瓷缺陷检测_第6张图片
在HRNet-Image-Classification-master\lib\core\function.py里面的def validate函数,添加

print('class:{}'.format(output.argmax(1)))

以打印识别的类别index。
运行:

python  tools/vaild.py --cfg  experiments/cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml

会打印出所属类别的index。

你可能感兴趣的:(机器学习与深度学习,paddlepaddle,计算机视觉,python)