- 训练命令:
python newcrfs/train.py configs/arguments_train_nyu.txt
- main
if __name__ == '__main__':
    # main() returns -1 when it rejects the GPU configuration and None
    # otherwise; hand that to sys.exit so the failure is visible in the
    # process exit status instead of being silently dropped.
    sys.exit(main())
- main()
def _copy_into(sources, dest_dir):
    """Copy every file in ``sources`` into ``dest_dir``, best effort.

    The directory (including missing parents) is created first. Any
    ``OSError`` is reported on stdout and skipped rather than raised, which
    mirrors the original shell commands whose failures never stopped the
    run.
    """
    import shutil

    try:
        os.makedirs(dest_dir, exist_ok=True)
    except OSError as err:
        print("Could not create '{}': {}".format(dest_dir, err))
        return
    for src in sources:
        try:
            shutil.copy(src, dest_dir)
        except OSError as err:
            print("Could not copy '{}' to '{}': {}".format(src, dest_dir, err))


def main():
    """Snapshot the run's config and sources, then launch training.

    Reads the module-level ``args`` and ``sys.argv[1]`` (the argument file
    given on the command line). Copies the argument file, ``train.py`` and
    the ``networks`` / ``dataloaders`` sources into
    ``<log_directory>/<model_name>``, then starts ``main_worker`` either
    through ``mp.spawn`` (one process per GPU) or directly.

    Returns:
        -1 if more than one GPU is visible without
        ``--multiprocessing_distributed``; on the single-process path,
        whatever ``main_worker`` returns; otherwise None.
    """
    import glob

    if args.mode != 'train':
        ...  # body elided in these notes

    # Paths are handled with os/shutil instead of shell strings so that
    # directory names containing spaces or shell metacharacters are safe.
    run_dir = os.path.join(args.log_directory, args.model_name)
    _copy_into([sys.argv[1]], run_dir)

    save_files = True
    if save_files:
        _copy_into(['newcrfs/train.py'], run_dir)
        _copy_into(glob.glob('newcrfs/networks/*.py'),
                   os.path.join(run_dir, 'networks'))
        _copy_into(glob.glob('newcrfs/dataloaders/*.py'),
                   os.path.join(run_dir, 'dataloaders'))

    torch.cuda.empty_cache()
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node > 1 and not args.multiprocessing_distributed:
        print("This machine has more than 1 gpu. Please specify --multiprocessing_distributed, or set \'CUDA_VISIBLE_DEVICES=0\'")
        return -1

    if args.do_online_eval:
        print("You have specified --do_online_eval.")
        print("This will evaluate the model every eval_freq {} steps and save best models for individual eval metrics."
              .format(args.eval_freq))

    if args.multiprocessing_distributed:
        # world_size becomes the total number of processes across all nodes.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
        return None
    # Pass the worker's status through (it returns -1 on a NaN loss).
    return main_worker(args.gpu, ngpus_per_node, args)
- main_worker(0, 1, args)
def main_worker(gpu, ngpus_per_node, args):
    """Run the training loop for one worker process.

    Builds the model and optimizer, optionally restores a checkpoint, then
    iterates over the training data for ``args.num_epochs`` epochs with a
    polynomially decayed learning rate. The logging process writes
    TensorBoard summaries every ``args.log_freq`` steps; when
    ``args.do_online_eval`` is set the model is evaluated every
    ``args.eval_freq`` steps and a checkpoint is kept per best metric.

    Args:
        gpu: Device index for this worker; stored on ``args.gpu``.
        ngpus_per_node: Number of GPUs on this node; together with
            ``args.rank`` it selects the process that logs and saves.
        args: Parsed training options. Mutated in place (``args.gpu``).

    Returns:
        -1 if the logging process observes a NaN loss, otherwise None.
    """
    args.gpu = gpu
    if args.gpu is not None:
        print("== Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        ...  # distributed initialisation elided in these notes

    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=args.pretrain)
    model.train()
    ...  # elided in these notes

    if args.distributed:
        ...  # distributed model wrapping elided in these notes
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()
    print("== Model Initialized")

    global_step = 0
    # Nine tracked metrics: for the first six lower is better (start high),
    # for the last three higher is better (start at zero).
    best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(9, dtype=np.int32)

    optimizer = torch.optim.Adam([{'params': model.module.parameters()}],
                                 lr=args.learning_rate)

    model_just_loaded = False
    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("== Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if not args.retrain:
                try:
                    global_step = checkpoint['global_step']
                    best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu()
                    best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu()
                    best_eval_steps = checkpoint['best_eval_steps']
                except KeyError:
                    print("Could not load values for online evaluation")
            # Missing bookkeeping keys are tolerated above, so do not let the
            # status message crash on a missing 'global_step' either.
            print("== Loaded checkpoint '{}' (global_step {})".format(
                args.checkpoint_path, checkpoint.get('global_step', 'unknown')))
            # Release the checkpoint only on the path where it was bound;
            # deleting it unconditionally raised UnboundLocalError when the
            # file was not found.
            del checkpoint
        else:
            print("== No checkpoint found at '{}'".format(args.checkpoint_path))
        model_just_loaded = True

    cudnn.benchmark = True

    dataloader = NewDataLoader(args, 'train')
    dataloader_eval = NewDataLoader(args, 'online_eval')

    # Only one process per node writes logs: either we are not spawning
    # workers at all, or this worker's rank is a multiple of ngpus_per_node.
    # Evaluated here, after the elided distributed-setup sections above.
    is_main_process = (not args.multiprocessing_distributed
                       or args.rank % ngpus_per_node == 0)

    if is_main_process:
        writer = SummaryWriter(args.log_directory + '/' + args.model_name + '/summaries', flush_secs=30)
        if args.do_online_eval:
            if args.eval_summary_directory != '':
                eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name)
            else:
                eval_summary_path = os.path.join(args.log_directory, args.model_name, 'eval')
            eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)

    start_time = time.time()
    duration = 0

    num_log_images = args.batch_size
    # -1 means "not set": fall back to a tenth of the initial learning rate.
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)
    print("== Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum/var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    # When resuming, start from the epoch the restored step falls into.
    epoch = global_step // steps_per_epoch

    while epoch < args.num_epochs:
        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()

            # Tensors track gradients directly; the deprecated
            # torch.autograd.Variable wrapper is not needed.
            image = sample_batched['image'].cuda(args.gpu, non_blocking=True)
            depth_gt = sample_batched['depth'].cuda(args.gpu, non_blocking=True)

            depth_est = model(image)

            # Pixels at or below the threshold are excluded from the loss.
            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            else:
                mask = depth_gt > 1.0

            loss = silog_criterion.forward(depth_est, depth_gt, mask.to(torch.bool))
            loss.backward()

            # Polynomial decay (power 0.9) from learning_rate down to
            # end_learning_rate over the whole run; same for every group.
            current_lr = (args.learning_rate - end_learning_rate) * (1 - global_step / num_total_steps) ** 0.9 + end_learning_rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.step()

            if is_main_process:
                print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(epoch, step, steps_per_epoch, global_step, current_lr, loss))
                if np.isnan(loss.cpu().item()):
                    print('NaN in loss occurred. Aborting training.')
                    return -1

            duration += time.time() - before_op_time
            if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
                var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * args.log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step - 1.0) * time_sofar
                if is_main_process:
                    print("{}".format(args.model_name))
                print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(), var_sum.item()/var_cnt, time_sofar, training_time_left))

                if is_main_process:
                    writer.add_scalar('silog_loss', loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average', var_sum.item()/var_cnt, global_step)
                    # Replace near-zero ground truth with a large value so
                    # the reciprocal below does not divide by zero.
                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3, depth_gt)
                    for i in range(num_log_images):
                        writer.add_image('depth_gt/image/{}'.format(i), normalize_result(1/depth_gt[i, :, :, :].data), global_step)
                        writer.add_image('depth_est/image/{}'.format(i), normalize_result(1/depth_est[i, :, :, :].data), global_step)
                        writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)
                    writer.flush()

            if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
                time.sleep(0.1)
                model.eval()
                with torch.no_grad():
                    eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True)
                if eval_measures is not None:
                    for i in range(9):
                        eval_summary_writer.add_scalar(eval_metrics[i], eval_measures[i].cpu(), int(global_step))
                        measure = eval_measures[i]
                        is_best = False
                        if i < 6 and measure < best_eval_measures_lower_better[i]:
                            old_best = best_eval_measures_lower_better[i].item()
                            best_eval_measures_lower_better[i] = measure.item()
                            is_best = True
                        elif i >= 6 and measure > best_eval_measures_higher_better[i-6]:
                            old_best = best_eval_measures_higher_better[i-6].item()
                            best_eval_measures_higher_better[i-6] = measure.item()
                            is_best = True
                        if is_best:
                            # Drop the previous best checkpoint for this
                            # metric before writing the new one.
                            old_best_step = best_eval_steps[i]
                            old_best_name = '/model-{}-best_{}_{:.5f}'.format(old_best_step, eval_metrics[i], old_best)
                            model_path = args.log_directory + '/' + args.model_name + old_best_name
                            if os.path.exists(model_path):
                                os.remove(model_path)
                            best_eval_steps[i] = global_step
                            model_save_name = '/model-{}-best_{}_{:.5f}'.format(global_step, eval_metrics[i], measure)
                            print('New best for {}. Saving model: {}'.format(eval_metrics[i], model_save_name))
                            checkpoint = {'global_step': global_step,
                                          'model': model.state_dict(),
                                          'optimizer': optimizer.state_dict(),
                                          'best_eval_measures_higher_better': best_eval_measures_higher_better,
                                          'best_eval_measures_lower_better': best_eval_measures_lower_better,
                                          'best_eval_steps': best_eval_steps
                                          }
                            torch.save(checkpoint, args.log_directory + '/' + args.model_name + model_save_name)
                    eval_summary_writer.flush()
                model.train()
                block_print()
                enable_print()

            # The "just loaded" guard only suppresses logging/eval on the
            # first step after a restore.
            model_just_loaded = False
            global_step += 1

        epoch += 1

    if is_main_process:
        writer.close()
        if args.do_online_eval:
            eval_summary_writer.close()
- model = NewCRFDepth(version='large07', inv_depth=False, max_depth=10, pretrained='swin_large_patch4_window7_224_22k.pth')
class NewCRFDepth(nn.Module):
    """
    Depth network based on neural window FC-CRFs architecture.

    A SwinTransformer backbone produces four feature maps. A PSP decoder
    turns them into the coarsest query, which four NewCRF stages refine from
    the deepest feature map to the shallowest, with a 2x pixel shuffle
    between stages. A DispHead maps the final embedding to a prediction that
    is scaled by ``max_depth``.

    Args:
        version: Encoder name; the last two characters are parsed as the
            backbone window size and the prefix selects the size preset.
        inv_depth: Stored on the instance as ``self.inv_depth``.
        pretrained: Passed to the backbone's ``init_weights``.
        frozen_stages: Forwarded to the backbone configuration.
        min_depth: Stored on the instance as ``self.min_depth``.
        max_depth: Scale applied to the head output in ``forward``.
        **kwargs: Accepted and ignored.
    """
    def __init__(self, version='large07', inv_depth=False, pretrained='swin_large_patch4_window7_224_22k.pth',
                 frozen_stages=-1, min_depth=0.1, max_depth=10.0, **kwargs):
        super().__init__()

        self.inv_depth = inv_depth
        self.with_auxiliary_head = False
        self.with_neck = False

        norm_cfg = dict(type='BN', requires_grad=True)

        window_size = int(version[-2:])

        if version[:-2] == 'base':
            ...  # preset elided in these notes
        elif version[:-2] == 'large':
            embed_dim = 192
            depths = [2, 2, 18, 2]
            num_heads = [6, 12, 24, 48]
            in_channels = [192, 384, 768, 1536]
        elif version[:-2] == 'tiny':
            ...  # preset elided in these notes

        backbone_cfg = dict(
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            ape=False,
            drop_path_rate=0.3,
            patch_norm=True,
            use_checkpoint=False,
            frozen_stages=frozen_stages
        )

        # From here on embed_dim is the decoder's channel width, not the
        # backbone's embedding size.
        embed_dim = 512
        decoder_cfg = dict(
            in_channels=in_channels,
            in_index=[0, 1, 2, 3],
            pool_scales=(1, 2, 3, 6),
            channels=embed_dim,
            dropout_ratio=0.0,
            num_classes=32,
            norm_cfg=norm_cfg,
            align_corners=False
        )

        self.backbone = SwinTransformer(**backbone_cfg)

        # Stage sizes come from these tables rather than repeated literals,
        # so the CRF inputs follow whichever encoder preset was chosen.
        win = 7
        crf_dims = [128, 256, 512, 1024]
        v_dims = [64, 128, 256, 512]
        self.crf3 = NewCRF(input_dim=in_channels[3], embed_dim=crf_dims[3], window_size=win, v_dim=v_dims[3], num_heads=32)
        self.crf2 = NewCRF(input_dim=in_channels[2], embed_dim=crf_dims[2], window_size=win, v_dim=v_dims[2], num_heads=16)
        self.crf1 = NewCRF(input_dim=in_channels[1], embed_dim=crf_dims[1], window_size=win, v_dim=v_dims[1], num_heads=8)
        self.crf0 = NewCRF(input_dim=in_channels[0], embed_dim=crf_dims[0], window_size=win, v_dim=v_dims[0], num_heads=4)

        self.decoder = PSP(**decoder_cfg)
        self.disp_head1 = DispHead(input_dim=crf_dims[0])

        self.up_mode = 'bilinear'

        self.min_depth = min_depth
        self.max_depth = max_depth

        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        """Initialise the backbone from ``pretrained`` and reset the decoder."""
        print(f'== Load encoder backbone from: {pretrained}')
        self.backbone.init_weights(pretrained=pretrained)
        self.decoder.init_weights()

    def forward(self, imgs):
        """Predict depth for ``imgs``.

        Returns:
            The head output multiplied by ``self.max_depth``.
        """
        feats = self.backbone(imgs)
        ppm_out = self.decoder(feats)

        # Coarse-to-fine refinement: each CRF stage combines one backbone
        # feature map with the pixel-shuffled output of the previous stage.
        e3 = self.crf3(feats[3], ppm_out)
        e3 = nn.PixelShuffle(2)(e3)
        e2 = self.crf2(feats[2], e3)
        e2 = nn.PixelShuffle(2)(e2)
        e1 = self.crf1(feats[1], e2)
        e1 = nn.PixelShuffle(2)(e1)
        e0 = self.crf0(feats[0], e1)

        d1 = self.disp_head1(e0, 4)
        depth = d1 * self.max_depth
        return depth