ResNet训练单通道图像分类网络(Pytorch)

前言

 ResNet是一个比较成熟的深度学习分类模型,目前有ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152,同时,该分类模型常用于RGB(三通道)彩色图像的分类任务,如在ImageNet的训练;而在单通道图像(灰度图像)的训练和测试较少。如何使ResNet在单通道图像上训练,如何修改网络模型参数和读取图像,本文将一一进行讲解。

步骤

第一步:构建数据集

  • 数据集的结构应该是这样的
    ResNet训练单通道图像分类网络(Pytorch)_第1张图片
  • 图像的格式:8bit,jpg格式

第二步:修改网络模型

  • 法1:直接修改定义的ResNet网络模型
     在model.py中,修改ResNet的第一层卷积层输入通道为1(彩色为3)
self.conv1 = nn.Conv2d(1, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)
  • 法2:在train.py文件中,进行如下修改,也可以达到法1的效果
model = resnet18(num_classes=3)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model = model.to(device)

第三步:修改读取数据方式

  • 一般我们用torchvision.datasets.ImageFolder()读取数据,但在读取单通道数据时,此函数会自动将单通道图像转换为三通道图像(r=g=b),此时如果不进行其他操作,就会报错

  • 这是ImageFolder()函数的定义:留意读取的图像为PIL图像,且会转换为RGB格式
    ResNet训练单通道图像分类网络(Pytorch)_第2张图片
    在这里插入图片描述

  • 修改方法:
     修改transform(图像预处理操作)
      添加transforms.Grayscale(1),将图像转换为单通道图像(经实验,图像矩阵的数据并不会发生变化)
      transforms.Normalize修改如下,第一个参数为mean,第二个参数为std,因为是单通道,所以进行Z-Score时仅需要对一个通道进行操作
    ResNet训练单通道图像分类网络(Pytorch)_第3张图片

data_transforms = {
     
    'train': transforms.Compose([
        transforms.Grayscale(1),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, ], [0.229, ])
    ]),
    'val': transforms.Compose([
        transforms.Grayscale(1),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, ], [0.229, ])
    ])
}

第四步:训练分类网络并测试(注意测试时transform与‘val’方式一样)

总结

 ResNet训练单通道主要修改两个部分,一个是ResNet模型第一层卷积层的in_channels=1,另一个是transform中添加Grayscale(1)以及修改Normalize。其实很简单,只是有时忽略了ImageFolder会自动将灰度图转换为RGB图,导致出错,希望本文能帮助您!

参考资料:

可以参考这位up主github里面的Test5_resnet,并在此基础上进行上述修改,训练自己的灰度图像!
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_learning

你可能感兴趣的:(pytorch)