加载数据并转换为tensor格式
'''DATA LOADING'''
log_string('Load dataset ...')
# Relative path to the point-cloud classification dataset.
DATA_PATH = 'data/modelnet40_normal_resampled/'

# Arguments common to both splits; only `split` differs between them.
dataset_kwargs = dict(root=DATA_PATH, npoint=args.num_point, normal_channel=args.normal)
# Training split (9843 samples according to the original author's note).
TRAIN_DATASET = ModelNetDataLoader(split='train', **dataset_kwargs)
# Test split (2468 samples according to the original author's note).
TEST_DATASET = ModelNetDataLoader(split='test', **dataset_kwargs)

# Wrap each dataset in a batched loader with 4 worker processes.
# Only the training loader shuffles its samples.
loader_kwargs = dict(batch_size=args.batch_size, num_workers=4)
trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, shuffle=True, **loader_kwargs)
testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, shuffle=False, **loader_kwargs)
模型加载并设置初始参数
'''MODEL LOADING'''
# Number of classification categories.
num_class = 40
# Import the network module whose name is given by args.model.
MODEL = importlib.import_module(args.model)
# Copy the selected model's source file and pointnet_util.py into the
# experiment directory.
shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
shutil.copy('./models/pointnet_util.py', str(experiment_dir))

# Build the network and its loss object, and move both to the GPU.
classifier = MODEL.get_model(num_class, normal_channel=args.normal).cuda()
criterion = MODEL.get_loss().cuda()

# --------- Resume from an existing checkpoint when one is usable ---------
# Best-effort: any failure to read or apply the checkpoint falls back to
# training from scratch. `except Exception` (instead of a bare `except`)
# keeps that fallback while letting KeyboardInterrupt / SystemExit propagate.
try:
    checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
    start_epoch = checkpoint['epoch']
    classifier.load_state_dict(checkpoint['model_state_dict'])
    log_string('Use pretrain model')
except Exception:
    log_string('No existing model, starting training from scratch...')
    start_epoch = 0

# --------- Optimizer ---------
if args.optimizer == 'Adam':
    optimizer = torch.optim.Adam(
        classifier.parameters(),
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.decay_rate,
    )
else:
    # NOTE(review): this branch uses a fixed lr=0.01 and does not read
    # args.learning_rate or args.decay_rate — confirm this is intended.
    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

# --------- Initial training state ---------
# StepLR scales the learning rate by gamma=0.7 every 20 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
global_epoch = 0
global_step = 0
# Best accuracies seen so far; updated after each evaluation.
best_instance_acc = 0.0
best_class_acc = 0.0
mean_correct = []
模型训练过程中需要对每一个epoch的训练效果进行评估。
每一个epoch内包含多个batch:依次提取每个batch_id对应的data,用其训练模型,得到该batch的correct并存入mean_correct列表。
最后通过均值np.mean(mean_correct)得到train_instance_acc(即该epoch的训练准确率),并通过log_string写入日志文件(记录每个epoch的准确率)。
每训练完一个epoch,调用instance_acc, class_acc = test(classifier.eval(), testDataLoader)进行评估:
# Evaluation step. It reads `epoch`, which is set by code outside this snippet.
with torch.no_grad():
    # Switch the network to eval mode and score it with test().
    eval_model = classifier.eval()
    instance_acc, class_acc = test(eval_model, testDataLoader)

    # Decide once whether this epoch matches or beats the best instance
    # accuracy so far; the same decision drives both the bookkeeping
    # update and the checkpoint save below.
    is_new_best = instance_acc >= best_instance_acc
    if is_new_best:
        best_instance_acc = instance_acc
        best_epoch = epoch + 1
    if class_acc >= best_class_acc:
        best_class_acc = class_acc

    log_string('Test Instance Accuracy: %f, Class Accuracy: %f'% (instance_acc, class_acc))
    log_string('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc))

    if is_new_best:
        logger.info('Save model...')
        savepath = str(checkpoints_dir) + '/best_model.pth'
        log_string('Saving at %s'% savepath)
        # Save the network weights and optimizer state together with the
        # epoch number and the accuracies measured in this evaluation.
        torch.save(
            {
                'epoch': best_epoch,
                'instance_acc': instance_acc,
                'class_acc': class_acc,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            },
            savepath,
        )
global_epoch += 1