针对刚接触深度学习的小伙伴,肯定很想自己亲手搭建一个网络模型,训练模型。今天作者就五步教大家简单快速搭建一个分类网络,并训练模型,希望对初学者有一定帮助,欢迎大家收藏关注,作者将不断分享更新深度学习中的一些重要知识点。
作者是训练分类工作服和非工作服,如下图所示,为数据格式排布,很简单。其中文件夹“1”中存放的是没有穿工作服的图片,“0”文件夹中存放的是穿工作服的图片;train_1.txt 中的内容为“1”文件夹下没有穿工作服的图片路径和标签(每一行末尾设置为1),同理,想必大家也清楚train_0.txt内容了吧,两个txt文件由Build_path_toTXT.py生成而来,直接上图。
train_1.txt里的内容如下图
if __name__ == '__main__':
#Config类的对象,Config类中定义了众多参数,学习率、训练数据路径、梯度下降方法,相当于C语言中的宏定义
opt = Config()
#定义使用哪个GPU训练
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
#gpu训练还是cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#测试数据加载,分别用torch.utils.data.dataset与torch.utils.data.DataLoader API来完成,具体参数可以看源码定义,有解释,很好理解
test_dataset = Dataset(opt.test_root, opt.test_list, phase='test', input_shape=opt.input_shape)
testloader = data.DataLoader(test_dataset,
shuffle=False,
pin_memory=True,
num_workers=opt.num_workers)
#训练数据加载
train_dataset = Dataset(opt.train_root, opt.train_list, phase='train', input_shape=opt.input_shape)
trainloader = data.DataLoader(train_dataset,
batch_size=opt.train_batch_size,
shuffle=True,
pin_memory=True,
num_workers=opt.num_workers)
代码如下(示例):
if opt.loss == 'focal_loss':
#焦点损失函数,自定义
criterion = FocalLoss(gamma=2)
else:
#pytorch自带交叉熵损失函数,直接用
criterion = torch.nn.CrossEntropyLoss()
作者直接使用的是resnet34网络,该网络在有很多开源代码,可以参考自己写,也可以直接拿来用。代码如下(示例):
model = resnet34()
#是否有预训练模型
if opt.finetune == True:
model.load_state_dict(torch.load(opt.load_model_path))
#多GPU训练
model = torch.nn.DataParallel(model)
#导入到gpu
model.to(device)
定义梯度下降方法,学习率调整策略等。作者使用的是SGD梯度下降算法和余弦退火方法,都可以直接从Pytorch中调用。代码如下(示例):
#总batch数
total_batch = len(trainloader)
NUM_BATCH_WARM_UP = total_batch * 5
#SGD梯度下降算法
optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
#余弦退火
scheduler = CosineDecayLR(optimizer, opt.max_epoch * total_batch, opt.lr, 1e-6, NUM_BATCH_WARM_UP)
迭代训练,一个for循环搞定。代码如下(示例):
print('{} train iters per epoch in dataset'.format(len(trainloader)))
for epoch in range(0, opt.max_epoch):
#开始训练迭代,train()需要另外实现
train(model, criterion, optimizer, scheduler, trainloader, epoch)
if epoch % opt.save_interval == 0 or epoch == (opt.max_epoch - 1):
#保存模型
torch.save(model.module.state_dict(), 'checkpoints/model-epoch-'+str(epoch) + '.pth')
#验证
eval_train(model, criterion, testloader)
简单的分类网络,基本就这五个步骤搞定,这只是作者项目中的部分代码,仅供大家参考学习整个流程,欢迎大家收藏关注,转发!后续作者会不断更新
转发该博文务必注明出处