Implementing CenterNet in PyTorch — Learning PyTorch as a Beginner


Implementing CenterNet in PyTorch

   This is my first technical blog post, so let me ramble for a moment: in my year-plus of working I have always used Caffe and kept meaning to switch to PyTorch, but with no project to push me along I only read a few Chinese tutorials, and it was hard to see how to get from theory to actual application. Half a month ago I thought, "the best time to plant a tree was ten years ago; the second best time is now," so I decided to dig in by hand. Since this poor wage-earner only owns a 16xx-series GPU, the training batch_size could only be set to 2. Whenever I tested an intermediate model the results were terrible, but after one weekend the training suddenly converged, and the results turned out pretty good~

PyTorch Training Workflow

  Being used to Caffe's way of training, I felt very disoriented picking up PyTorch, so before digging in I went through a lot of material, trying to form a rough picture of PyTorch's overall training workflow first and then tackle each module one by one. My understanding of the overall workflow is:

  1. Data preparation: get the input data and labels ready in Tensor form;
  2. Forward pass: compute the network output (output), then compute the loss (loss);
  3. Backward pass to update parameters: (1) zero the gradients computed in the previous iteration, (2) backpropagate to compute the new gradients, (3) update the weight parameters;
  4. Save the model, print whatever information is needed, and so on.
  Data preparation here means deciding in advance the label format that the loss computation will need. Taking my hands-on example: I am training a multi-class detection task. Suppose a raw annotation line looks like:

 XXXXX.jpg  1 1359 417 1405 453  // here 1 is the id of one class to detect, followed by xmin ymin xmax ymax

  For the CenterNet loss implementation, what is needed are four matrix-form labels: label_offset, label_mask, label_size, and label_gauss_heatmap. In other words, a data-preprocessing step has to convert the raw annotation into these four labels required by the loss computation. That is what "data preparation" means here.
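  To make this concrete, below is a minimal sketch of such a conversion for a single image. The function name make_centernet_targets, the output stride of 4, and the fixed Gaussian radius are placeholder assumptions of mine rather than the author's actual preprocessing; the official CenterNet derives the radius from the box size.

import numpy as np

def make_centernet_targets(boxes, class_ids, num_classes, out_h, out_w, stride=4, radius=2):
    """Build the four CenterNet labels for one image.

    boxes: iterable of (xmin, ymin, xmax, ymax) in input-image pixels.
    Returns heatmap (C,H,W), offset (2,H,W), size (2,H,W), mask (1,H,W).
    """
    heatmap = np.zeros((num_classes, out_h, out_w), dtype=np.float32)
    offset = np.zeros((2, out_h, out_w), dtype=np.float32)
    size = np.zeros((2, out_h, out_w), dtype=np.float32)
    mask = np.zeros((1, out_h, out_w), dtype=np.float32)

    for (xmin, ymin, xmax, ymax), cls in zip(boxes, class_ids):
        # Box center mapped onto the stride-reduced output feature map
        cx = (xmin + xmax) / 2.0 / stride
        cy = (ymin + ymax) / 2.0 / stride
        ix, iy = int(cx), int(cy)
        if not (0 <= ix < out_w and 0 <= iy < out_h):
            continue

        # Splat a Gaussian peak around the integer center cell
        for y in range(max(0, iy - 3 * radius), min(out_h, iy + 3 * radius + 1)):
            for x in range(max(0, ix - 3 * radius), min(out_w, ix + 3 * radius + 1)):
                g = np.exp(-((x - ix) ** 2 + (y - iy) ** 2) / (2.0 * radius ** 2))
                heatmap[cls, y, x] = max(heatmap[cls, y, x], g)

        # Offset recovers the sub-pixel error lost by rounding the center
        offset[0, iy, ix] = cx - ix
        offset[1, iy, ix] = cy - iy
        # Size regresses the box width/height on the feature map
        size[0, iy, ix] = (xmax - xmin) / stride
        size[1, iy, ix] = (ymax - ymin) / stride
        # Mask marks the cells that contribute to the regression losses
        mask[0, iy, ix] = 1.0

    return heatmap, offset, size, mask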

  The forward pass computes the network output: the image runs through the whole network, and what comes out is again defined by the task. For the CenterNet I built, the outputs are pred_obj_center, pred_obj_offset, and pred_obj_size. The network outputs and the prepared labels are then fed together into the loss module to compute the loss.
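  The loss module is only referenced by name in the training code below, so here is a rough sketch of what losses.anchor_free_focalloss and losses.smoothl1loss_with_mask might look like. It follows the penalty-reduced focal loss from the CenterNet paper (with the usual alpha=2, beta=4) plus a masked smooth L1, and assumes pred_obj_center has already passed through a sigmoid; the author's actual losses module may differ.

import torch
import torch.nn.functional as F

def anchor_free_focalloss(pred, gt, alpha=2, beta=4, eps=1e-6):
    """Penalty-reduced pixel-wise focal loss over the center heatmap.

    pred/gt: (N, C, H, W); gt is the Gaussian heatmap, exactly 1.0 at centers.
    """
    pred = pred.clamp(eps, 1.0 - eps)
    pos = gt.eq(1.0).float()  # exact object-center locations
    neg = 1.0 - pos
    pos_loss = -((1 - pred) ** alpha) * torch.log(pred) * pos
    # (1 - gt)^beta down-weights negatives that sit close to a center
    neg_loss = -((1 - gt) ** beta) * (pred ** alpha) * torch.log(1 - pred) * neg
    num_pos = pos.sum().clamp(min=1.0)  # avoid division by zero
    return (pos_loss.sum() + neg_loss.sum()) / num_pos

def smoothl1loss_with_mask(pred, target, mask):
    """Smooth L1 on the offset/size maps, counted only where mask == 1."""
    mask = mask.expand_as(pred)
    loss = F.smooth_l1_loss(pred * mask, target * mask, reduction='sum')
    return loss / mask.sum().clamp(min=1.0)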

  Backward pass and parameter update: PyTorch computes gradients and updates parameters automatically; the following few lines of code are all it takes:

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

  Saving the model is also straightforward: torch.save does the job.
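  The reverse direction, loading a checkpoint to resume training, is what the commented-out recover branch in the code below hints at. A minimal sketch, assuming the same checkpoint keys as the save call in the training loop:

# Minimal resume sketch: the keys match the dict saved by the training loop.
checkpoint = torch.load(cfg.model_file, map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
step = checkpoint['step']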

The full training code is as follows:

import os

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

import centerNet  # project-local module: the network definition
import dataset    # project-local module: TrainDataset
import losses     # project-local module: the loss functions
# import preprocess  # only needed if the RGB2YUV444 lambda below is enabled


def train(cfg):
    step = 0
    if not os.path.exists(cfg.model_dir):
        os.mkdir(cfg.model_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_id
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = centerNet.CenterNetResNet18(cfg.category)
    net.to(device)

    optimizer = optim.SGD(net.parameters(), lr=cfg.base_lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=cfg.step_size, gamma=cfg.gamma)

    # if cfg.recover:
    #     cfg.model_file = 'F:/A01_Python/pytorch/CenterNet/train/model/model_iter_308000.pth'
    #     checkpoint = torch.load(cfg.model_file)
    #     net.load_state_dict(checkpoint['model_state_dict'])


    # Preprocessing
    input_size = [int(x) for x in cfg.input_size.split(',')]
    mean = [float(x) for x in cfg.mean.split(',')]
    std = [float(x) for x in cfg.std.split(',')]
    transform = transforms.Compose([
        transforms.Resize(input_size),
        #transforms.Lambda(lambda img: preprocess.RGB2YUV444(img)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])

    transformed_dataset = dataset.TrainDataset(cfg.label_file, cfg.image_root_dir, cfg.category, input_size, transform=transform)

    dataloader = DataLoader(transformed_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

    while step < cfg.max_steps:
        for i_batch, sample_batched in enumerate(dataloader):
            step += 1
            inputs, label_offset, label_mask, label_size, label_gauss_heatmap = sample_batched

            # Put the data on the GPU
            inputs = inputs.to(device)
            label_gauss_heatmap = label_gauss_heatmap.to(device)
            label_offset = label_offset.to(device)
            label_size = label_size.to(device)
            label_mask = label_mask.to(device)

            # Forward
            output = net(inputs)
            pred_obj_center = output[0]
            pred_obj_offset = output[1]
            pred_obj_size = output[2]

            # Calculate the loss
            loss_center = losses.anchor_free_focalloss(pred_obj_center, label_gauss_heatmap)
            loss_offset = losses.smoothl1loss_with_mask(pred_obj_offset, label_offset, label_mask)
            loss_size = losses.smoothl1loss_with_mask(pred_obj_size, label_size, label_mask)
            loss = loss_center + loss_offset + 0.1 * loss_size

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Save the model
            if step % cfg.save_steps == 0:
                model_pathname = os.path.join(cfg.model_dir, 'model_iter_%d.pth' % step)
                torch.save({'model_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'step': step
                            }, model_pathname)

            # Print the information
            if step % cfg.print_steps == 0:
                lr = '{:g}'.format(scheduler.get_last_lr()[0])
                print_info = 'step = %d, lr = %s, loss = %.6f\n' \
                             'loss_center = %.6f\nloss_offset = %.6f\nloss_size = %.6f\n' \
                             % (step, lr, loss.item(),
                                loss_center.item(), loss_offset.item(), loss_size.item())
                print_info += '-----------------------------------------------------\n'
                with open(cfg.accuracy_file, 'a') as f:
                    f.write(print_info)
                print(print_info)

            if step >= cfg.max_steps:
                break
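
For completeness, here is one way train might be invoked. Every field name on cfg below is dictated by how train reads it, but the concrete values are placeholders of mine, not the author's actual settings:

from types import SimpleNamespace

# Placeholder values; only the field names come from how train(cfg) uses them.
cfg = SimpleNamespace(
    model_dir='./model', gpu_id='0', category=3,
    base_lr=1.25e-4, momentum=0.9, weight_decay=1e-4,
    step_size=100000, gamma=0.1,
    input_size='512,512', mean='0.485,0.456,0.406', std='0.229,0.224,0.225',
    label_file='./train_label.txt', image_root_dir='./images',
    batch_size=2, num_workers=2,
    max_steps=500000, save_steps=2000, print_steps=100,
    accuracy_file='./train_log.txt',
)
train(cfg)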
