pytorch实现Resnet系列的分类任务

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 一、数据集格式
  • 二、训练部分
  • 三、评价
  • 结果
  • 总结


前言

最近需要一系列传统分类方法做对比,所以就顺便把自己复现resnet系列分类实验的过程记录一下,还是老传统:文末有源码。


一、数据集格式

在这里插入图片描述pytorch实现Resnet系列的分类任务_第1张图片
其中training文件夹下的basal、her2都是要分类的类别,basal里面就是一张张图片了。
在这里给大家一个公开的10种猴子的分类数据集,已经分好类了,大家可以直接下载使用。数据集地址:https://www.kaggle.com/slothkong/10-monkey-species

二、训练部分

train.py文件下可以自行修改权值文件保存的地址和batch size:

if __name__ == '__main__':

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    if not os.path.exists('./logs'):
        os.makedirs('./logs')

    BATCH_SIZE = 16

修改数据集的地址

    train_dataset = datasets.ImageFolder("./datasets/training", transform=data_transform["train"])  # 训练集数据
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                               num_workers=2)  # 加载数据
    len_train = len(train_dataset)
    val_dataset = datasets.ImageFolder("./datasets/validation", transform=data_transform["val"])  # 测试集数据
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                             num_workers=2)  # 加载数据
    len_val = len(val_dataset)

可以自行选择网络类型,resnet50或者resnet34或者其他都可以,损失函数就只有一个CEloss,优化器使用的是adam,epoch根据自己的数据多少和bacth size自行修改。

    net = resnet50()
    loss_function = nn.CrossEntropyLoss()  # 设置损失函数
    optimizer = optim.Adam(net.parameters(), lr=0.0001)  # 设置优化器和学习率
    epoch = 100

三、评价

改完上面的参数后就可以训练了,训练结束之后可以对于分类结果进行评价,需要续改evaluate.py的相关内容,首先修改训练生产的权值文件的路径。

if __name__ == '__main__':
    model = torch.load("./logs/best.pth")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    class_correct = [0.] * 10
    class_total = [0.] * 10
    y_test, y_pred = [], []
    X_test = []

下面修改验证集或者测试集的指向路径。

    data_transform = transforms.Compose([transforms.Resize(256),
                                       transforms.CenterCrop(224),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    val_dataset = datasets.ImageFolder("./datasets/validation", transform=data_transform)  # 测试集数据
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                                 num_workers=2)  # 加载数据

    classes = val_dataset.classes

最后运行evaluate.py就可以。

结果

pytorch实现Resnet系列的分类任务_第2张图片

总结

以上就是今天要讲的内容,本文仅仅简单介绍了resnet分类网络的使用。
源码分享在网盘里面:网盘
提取码:57hs

你可能感兴趣的:(pytorch,分类,深度学习)