# Trainer class -- initialization section
class Trainer(object):
    """Run the search-phase training of an AutoDeeplab model.

    Network weights are optimised with SGD on ``train_loaderA`` while the
    architecture parameters are optimised with Adam on ``train_loaderB``.
    Progress is logged through a Tensorboard summary writer and checkpoints
    are written through a ``Saver``.
    """

    def __init__(self, args):
        """Build data loaders, model, optimisers and optionally resume.

        Args:
            args: Parsed argument namespace. The constructor reads (among
                others) ``workers``, ``dataset``, ``use_balanced_weights``,
                ``loss_type``, ``lr``, ``momentum``, ``weight_decay``,
                ``arch_lr``, ``arch_weight_decay``, ``lr_scheduler``,
                ``epochs``, ``min_lr``, ``cuda``, ``gpu_ids``, ``use_amp``,
                ``opt_level``, ``resume``, ``clean_module`` and ``ft``.
                ``args.start_epoch`` is overwritten when resuming or
                fine-tuning.

        Raises:
            NotImplementedError: balanced class weights were requested but
                the pre-computed weights file does not exist.
            RuntimeError: ``args.resume`` does not point to a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Mixed precision is used only when apex is importable AND requested.
        self.use_amp = bool(APEX_AVAILABLE and args.use_amp)
        self.opt_level = args.opt_level

        # Load the data. In search mode the training set comes as two
        # loaders, A (network weights) and B (architecture parameters).
        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
        (self.train_loaderA, self.train_loaderB, self.val_loader,
         self.test_loader, self.nclass) = make_data_loader(args, **kwargs)
        # Persistent iterator over train_loaderB, created on first use.
        self._search_iter = None

        # Optional per-class loss weights loaded from a pre-computed .npy file
        # (presumably class-balancing weights -- confirm against the dataset
        # tooling).
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset),
                                                args.dataset + '_classes_weights.npy')
            if not os.path.isfile(classes_weights_path):
                # Computing the weights on the fly is not implemented here: it
                # is unclear which of the two training loaders should be used.
                raise NotImplementedError(
                    "class weights file not found: '{}'".format(classes_weights_path))
            # Convert the numpy weights into a float32 torch tensor.
            weight = torch.from_numpy(np.load(classes_weights_path).astype(np.float32))
        else:
            weight = None

        # Loss criterion selected by args.loss_type (the original note says
        # the default is 'ce'; the default itself is defined elsewhere).
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define the AutoDeeplab network. The literal 12 is passed as the
        # second positional argument (presumably the number of layers --
        # TODO confirm against AutoDeeplab's signature).
        model = AutoDeeplab(self.nclass, 12, self.criterion, self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)

        # Network weights are optimised with SGD.
        optimizer = torch.optim.SGD(
            model.weight_parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
        self.model, self.optimizer = model, optimizer

        # Architecture parameters are optimised with Adam.
        self.architect_optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                                    lr=args.arch_lr, betas=(0.9, 0.999),
                                                    weight_decay=args.arch_weight_decay)

        # Evaluator used in validation() for accuracy / IoU metrics.
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        # TODO: Figure out if len(self.train_loader) should be divided by two?
        # (same question applies in other modules)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loaderA), min_lr=args.min_lr)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # mixed precision
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if self.opt_level in ('O2', 'O3') else None

            # fix for older pytorch versions with opt_level 'O1'
            if self.opt_level == 'O1' and self._version_tuple(torch.__version__) < (1, 3):
                for module in self.model.modules():
                    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape, dtype=module.running_var.dtype,
                                           device=module.running_var.device), requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape, dtype=module.running_var.dtype,
                                            device=module.running_var.device), requires_grad=False)

            self.model, [self.optimizer, self.architect_optimizer] = amp.initialize(
                self.model, [self.optimizer, self.architect_optimizer], opt_level=self.opt_level,
                keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale="dynamic")
            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print('currently cannot run with nn.DataParallel and optimization level', self.opt_level)
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            state_dict = checkpoint['state_dict']
            if args.clean_module:
                # The checkpoint was saved from a DataParallel wrapper, so its
                # keys carry a 'module.' prefix that has to be removed first.
                state_dict = self._strip_module_prefix(state_dict)
            # Always copy into the bare network: whether self.model is wrapped
            # depends on args.gpu_ids, not on how many GPUs the host has.
            copy_state_dict(self._unwrapped_model().state_dict(), state_dict)

            if not args.ft:
                # NOTE(review): this copies into the dict returned by
                # optimizer.state_dict(); whether that actually restores the
                # optimizer depends on copy_state_dict -- confirm, or switch
                # to self.optimizer.load_state_dict(checkpoint['optimizer']).
                copy_state_dict(self.optimizer.state_dict(), checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _version_tuple(version):
        """Return the leading ``(major, minor)`` of a version string as ints.

        Local build suffixes (``+cu113``) and non-digit tails are ignored so
        that versions compare numerically instead of lexicographically.
        """
        numbers = []
        for part in str(version).split('+')[0].split('.')[:2]:
            digits = ''
            for char in part:
                if not char.isdigit():
                    break
                digits += char
            numbers.append(int(digits) if digits else 0)
        return tuple(numbers)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` without leading 'module.' in keys."""
        prefix = 'module.'
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            # Only strip when the prefix is really there; blindly slicing the
            # first seven characters would mangle un-prefixed keys.
            name = key[len(prefix):] if key.startswith(prefix) else key
            cleaned[name] = value
        return cleaned

    def _unwrapped_model(self):
        """Return the bare network, unwrapping ``nn.DataParallel`` if present."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def _next_search_batch(self):
        """Return the next batch of ``train_loaderB``, restarting when exhausted.

        One iterator is kept alive across calls. Creating a fresh iterator on
        every training step would restart the loader each time and only ever
        yield the first batch of a new pass.
        """
        iterator = getattr(self, '_search_iter', None)
        if iterator is None:
            iterator = iter(self.train_loaderB)
        try:
            batch = next(iterator)
        except StopIteration:
            # Loader B is exhausted: start another pass over it.
            iterator = iter(self.train_loaderB)
            batch = next(iterator)
        self._search_iter = iterator
        return batch

    def _save_checkpoint(self, epoch, is_best):
        """Save weights of the bare network plus optimizer state via the saver."""
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self._unwrapped_model().state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best)

    def training(self, epoch):
        """Train for one epoch.

        Every batch of ``train_loaderA`` updates the network weights. Once
        ``epoch >= args.alpha_epoch`` each step additionally updates the
        architecture parameters on one batch of ``train_loaderB``.

        Args:
            epoch: Zero-based index of the current epoch.
        """
        # Accumulated (summed) training loss of this epoch.
        train_loss = 0.0
        self.model.train()
        # Progress bar over loader A.
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        # Visualise about ten times per epoch; the interval must never become
        # zero, otherwise the modulo below fails for loaders with < 10 batches.
        vis_interval = max(num_img_tr // 10, 1)
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            # Per-iteration scheduler call on the weight optimizer (presumably
            # sets the learning rate -- confirm against LR_Scheduler).
            self.scheduler(self.optimizer, i, epoch, self.best_pred)

            # ---- network weight step -------------------------------------
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.optimizer.step()

            # ---- architecture parameter step -----------------------------
            if epoch >= self.args.alpha_epoch:
                search = self._next_search_batch()
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search, target_search = image_search.cuda(), target_search.cuda()

                self.architect_optimizer.zero_grad()
                output_search = self.model(image_search)
                arch_loss = self.criterion(output_search, target_search)
                if self.use_amp:
                    with amp.scale_loss(arch_loss, self.architect_optimizer) as arch_scaled_loss:
                        arch_scaled_loss.backward()
                else:
                    arch_loss.backward()
                self.architect_optimizer.step()

            # Running mean of the weight loss shown on the progress bar.
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            num_images = i * self.args.batch_size + image.data.shape[0]

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        # When args.no_val is set, validation() is never run by main(), so a
        # checkpoint is saved here after every epoch instead.
        if self.args.no_val:
            self._save_checkpoint(epoch, False)

    def validation(self, epoch):
        """Evaluate on ``val_loader`` and checkpoint when mIoU improves.

        Args:
            epoch: Zero-based index of the epoch that was just trained.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            # No gradients are needed for evaluation.
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            num_images = i * self.args.batch_size + image.data.shape[0]

            # Turn the network output into class predictions (argmax, axis 1).
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        # Log and print the metrics.
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        # Keep the checkpoint only when mIoU beats the best value so far.
        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self._save_checkpoint(epoch, True)
def main():
    """Parse the search arguments, fill in defaults and run the search loop.

    Defaults that are filled in when the corresponding argument is ``None``:
    ``sync_bn`` (on only for multi-GPU CUDA runs), ``epochs`` (per-dataset
    table), ``batch_size`` (4 per device), ``test_batch_size`` (same as
    ``batch_size``) and ``checkname``.

    Raises:
        ValueError: ``--gpu_ids`` is not a comma-separated list of integers.
        KeyError: ``epochs`` is unset and the dataset has no default entry.
    """
    args = obtain_search_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')

    # gpu_ids is only parsed into a list when CUDA is used. Without CUDA it is
    # still the raw string, whose len() would count characters, not devices.
    num_devices = len(args.gpu_ids) if args.cuda else 1

    if args.sync_bn is None:
        args.sync_bn = bool(args.cuda and num_devices > 1)

    # default settings for epochs, batch_size and lr
    if args.epochs is None:
        epoches = {
            'coco': 30,
            'cityscapes': 40,
            'pascal': 50,
            'kd': 10,
        }
        args.epochs = epoches[args.dataset.lower()]

    if args.batch_size is None:
        args.batch_size = 4 * num_devices

    if args.test_batch_size is None:
        args.test_batch_size = args.batch_size

    if args.checkname is None:
        args.checkname = 'deeplab-' + str(args.backbone)

    print(args)
    torch.manual_seed(args.seed)

    # Build the complete training setup (data, model, optimisers, ...).
    trainer = Trainer(args)
    print('Starting Epoch:', trainer.args.start_epoch)
    print('Total Epoches:', trainer.args.epochs)
    try:
        for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
            trainer.training(epoch)
            # Validate on the last epoch of every eval_interval-sized window,
            # unless validation is disabled.
            if not trainer.args.no_val and epoch % args.eval_interval == (args.eval_interval - 1):
                trainer.validation(epoch)
    finally:
        # Close the summary writer even if training is interrupted or fails.
        trainer.writer.close()