ResNet50修改网络适应灰度图片并加载预训练模型

此博文是修改https://blog.csdn.net/jiacong_wang/article/details/105631229
这位大大的博文而成的,自己根据自己的情况稍微加了点东西

要修改的地方有4处

1.修改网络第一层,把3通道改为1
法一:直接在定义网络的地方修改

self.conv1 = nn.Conv2d(1, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)

法二:在调用网络模型的地方修改

model = resnet50()
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model = model.to(device)

2.修改读取数据的方式
--------修改之前:
ResNet50修改网络适应灰度图片并加载预训练模型_第1张图片
--------修改之后

 train_transformer = transforms.Compose([
        transforms.Grayscale(1), # 修改
        transforms.RandomHorizontalFlip(0.5),  
        transforms.ToTensor(),                
        transforms.Normalize(0.485, 0.229, inplace=True),  # 修改(-1,1)
    ])
  • 修改方法
  1. 修改transform(图像预处理操作)
      添加transforms.Grayscale(1),将图像转换为单通道图像(经实验,图像矩阵的数据并不会发生变化)
  2. transforms.Normalize修改如下,第一个参数为mean,第二个参数为std,因为是单通道,所以进行Z-Score时仅需要对一个通道进行操作,所以mean和std只需要一个值就行

3.修改读取数据集部分(mydataset.py)
-----修改之前
ResNet50修改网络适应灰度图片并加载预训练模型_第2张图片
-----修改之后

img = Image.open(path_img)

只需要把后面转RGB的部分去掉就行

4.因为加载预训练模型而修改网络
要加载预训练模型,第一层的权重参数肯定不能加载,则只需要把第一层的权重参数避开就行:

  • 加载方法
    net = resnet50()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet50-19c8e357.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    # 加载预训练模型并且把不需要的层去掉
    pre_state_dict = torch.load(model_weight_path)
    print("原模型", pre_state_dict.keys())
    new_state_dict = {}
    for k, v in net.state_dict().items():          # 遍历修改模型的各个层
        print("新模型", k)
        if k in pre_state_dict.keys() and k!= 'conv1.weight':
            new_state_dict[k] = pre_state_dict[k]  # 如果原模型的层也在新模型的层里面, 那新模型就加载原先训练好的权重
    net.load_state_dict(new_state_dict, False)
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)
    net.to(device)

你可能感兴趣的:(ResNet50修改网络适应灰度图片并加载预训练模型)