pointnet数据加载、模型加载和训练部分理解(pytorch)

pointnet的数据加载和模型加载

  • 一、数据加载
  • 二、模型加载并设置初始参数
  • 三、模型训练

一、数据加载

加载数据并转换为tensor格式

'''DATA LOADING'''
log_string('Load dataset ...')
DATA_PATH = 'data/modelnet40_normal_resampled/'  #点云分类数据相对路径

TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train', 
                                                 normal_channel=args.normal)
# 训练集:9843个样本                                                 
TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test',
                                                normal_channel=args.normal)
# 测试集:2468个样本
trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4)  #训练集转tensor格式
testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)   #测试集转tensor格式

二、模型加载并设置初始参数

模型加载并设置初始参数

'''MODEL LOADING'''
# 分类类别数目
num_class = 40
# import network module
MODEL = importlib.import_module(args.model)
shutil.copy('./models/%s.py' % args.model, str(experiment_dir)) #复制需要训练的模型的模块,保存至同级目录下的新文件夹models
shutil.copy('./models/pointnet_util.py', str(experiment_dir))  #复制预处理模块,保存至同级目录下的新文件夹models

classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda()  #获取模型,模型变量命名为classifier,并调用gpu
criterion = MODEL.get_loss().cuda()  #获取模型损失,损失变量命名为criterion,并调用gpu

#---------判断是否利用预训练模型------------#
try:
    checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
    start_epoch = checkpoint['epoch']
    classifier.load_state_dict(checkpoint['model_state_dict'])
    log_string('Use pretrain model')
except:
    log_string('No existing model, starting training from scratch...')
    start_epoch = 0

#--------设置优化器参数-----------------#
if args.optimizer == 'Adam':
    optimizer = torch.optim.Adam(
        classifier.parameters(),
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.decay_rate
    )
else:
    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

#----------设置训练初始参数---------------#
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
global_epoch = 0
global_step = 0
best_instance_acc = 0.0
best_class_acc = 0.0
mean_correct = []

三、模型训练

模型训练过程中需要对模型的每一个epoch训练效果进行评估,
每一个epoch内包含多个batch,需要对每个batch_id的data进行提取,进而对每个batch_id下的数据进行模型训练得到correct并存入mean_correct列表,
最后通过均值np.mean(mean_correct)得到train_instance_acc(每一个epoch下的准确率),并通过log_string存入日志文件中(记录每个epoch的准确率),
pointnet数据加载、模型加载和训练部分理解(pytorch)_第1张图片
每训练完一个epoch,调用instance_acc, class_acc = test(classifier.eval(), testDataLoader)进行评估,

# 性能评估
with torch.no_grad():
instance_acc, class_acc = test(classifier.eval(), testDataLoader)

if (instance_acc >= best_instance_acc):
    best_instance_acc = instance_acc
    best_epoch = epoch + 1

if (class_acc >= best_class_acc):
    best_class_acc = class_acc
log_string('Test Instance Accuracy: %f, Class Accuracy: %f'% (instance_acc, class_acc))
log_string('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc))

if (instance_acc >= best_instance_acc):
    logger.info('Save model...')
    savepath = str(checkpoints_dir) + '/best_model.pth'
    log_string('Saving at %s'% savepath)
    state = {
        'epoch': best_epoch,
        'instance_acc': instance_acc,
        'class_acc': class_acc,
        'model_state_dict': classifier.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    # 保存网络模型
    torch.save(state, savepath)
global_epoch += 1

你可能感兴趣的:(python)