提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
最近需要一系列传统分类方法做对比,所以就顺便把自己复现resnet系列分类实验的过程记录一下,还是老传统:文末有源码。
其中training文件夹下的basal、her2都是要分类的类别,basal里面就是一张张图片了。
在这里给大家一个公开的10种猴子的分类数据集,已经分好类了,大家可以直接下载使用。数据集地址:https://www.kaggle.com/slothkong/10-monkey-species
train.py文件下可以自行修改权值文件保存的地址和batch size:
if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))
if not os.path.exists('./logs'):
os.makedirs('./logs')
BATCH_SIZE = 16
修改数据集的地址
train_dataset = datasets.ImageFolder("./datasets/training", transform=data_transform["train"]) # 训练集数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=2) # 加载数据
len_train = len(train_dataset)
val_dataset = datasets.ImageFolder("./datasets/validation", transform=data_transform["val"]) # 测试集数据
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=2) # 加载数据
len_val = len(val_dataset)
可以自行选择网络类型,resnet50或者resnet34或者其他都可以,损失函数就只有一个CEloss,优化器使用的是adam,epoch根据自己的数据多少和bacth size自行修改。
net = resnet50()
loss_function = nn.CrossEntropyLoss() # 设置损失函数
optimizer = optim.Adam(net.parameters(), lr=0.0001) # 设置优化器和学习率
epoch = 100
改完上面的参数后就可以训练了,训练结束之后可以对于分类结果进行评价,需要续改evaluate.py的相关内容,首先修改训练生产的权值文件的路径。
if __name__ == '__main__':
model = torch.load("./logs/best.pth")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class_correct = [0.] * 10
class_total = [0.] * 10
y_test, y_pred = [], []
X_test = []
下面修改验证集或者测试集的指向路径。
data_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
val_dataset = datasets.ImageFolder("./datasets/validation", transform=data_transform) # 测试集数据
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=2) # 加载数据
classes = val_dataset.classes
最后运行evaluate.py就可以。
以上就是今天要讲的内容,本文仅仅简单介绍了resnet分类网络的使用。
源码分享在网盘里面:网盘
提取码:57hs