Training HQ-SAM (sam-hq)

sam-hq/train/train.py->

train_datasets = [...]   # list of training dataset configs (image/GT paths), i.e. the HQSeg-44K subsets
valid_datasets = [...]   # list of validation dataset configs

net = MaskDecoderHQ()->
- self.load_state_dict(torch.load(checkpoint))                    # load the pretrained SAM mask-decoder weights
- for n, p in self.named_parameters(): p.requires_grad = False    # freeze them; only the new HQ modules below are trained
- self.hf_token = nn.Embedding(1, 256)                            # the extra HQ-Output token
- self.hf_mlp = MLP(256, 256, 256 // 8, 3)                        # MLP head that turns the HQ token into mask weights
- self.compress_vit_feat = ...                                    # compress/upsample early ViT features
- self.embedding_encoder = ...                                    # upsample the final image embedding
- self.embedding_maskfeature = ...                                # refine SAM's upscaled mask features
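
A rough sketch of the three elided fusion modules, loosely following the released sam-hq code (layer shapes here are from memory and may differ slightly). Together they produce the HQ-Feature: embedding_encoder(image_embedding) + compress_vit_feat(early ViT feature), which is later added to embedding_maskfeature(SAM's upscaled mask feature).

import torch.nn as nn
from segment_anything.modeling.common import LayerNorm2d   # channel-wise LayerNorm used throughout SAM

transformer_dim, vit_dim = 256, 1024    # vit_l values; vit_b/vit_h use a different vit_dim

# early ViT feature (from interm_embeddings) -> 4x upsampled, transformer_dim//8 channels
compress_vit_feat = nn.Sequential(
    nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
    LayerNorm2d(transformer_dim),
    nn.GELU(),
    nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))

# final 64x64 image embedding -> the same 4x upsampled feature space
embedding_encoder = nn.Sequential(
    nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
    LayerNorm2d(transformer_dim // 4),
    nn.GELU(),
    nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2))

# small conv head that adapts SAM's upscaled mask feature before the HQ-Feature is added
embedding_maskfeature = nn.Sequential(
    nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
    LayerNorm2d(transformer_dim // 4),
    nn.GELU(),
    nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))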

main(net,train_datasets,valid_datasets,args)->

train_im_gt_list = get_im_gt_name_dict(train_datasets)
train_dataloader, train_datasets = create_dataloader(train_im_gt_list, transforms, ...)
- gos_datasets = [OnlineDataset(...), ...]                       # one dataset object per image/GT list
- gos_dataset = ConcatDataset(gos_datasets)                      # merge them into a single dataset
- sampler = DistributedSampler(gos_dataset)                      # shard the indices across GPUs
- batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
- dataloader = DataLoader(gos_dataset, batch_sampler=batch_sampler_train, ...)
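
The same chain written out with stock PyTorch pieces (a minimal sketch; dummy TensorDatasets stand in for OnlineDataset, and num_replicas/rank would normally come from the process group set up by torchrun):

import torch
from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# dummy stand-ins for the per-dataset OnlineDataset objects
parts = [TensorDataset(torch.randn(16, 3, 64, 64), torch.randint(0, 2, (16, 1, 64, 64)).float())
         for _ in range(3)]
gos_dataset = ConcatDataset(parts)

# in real training num_replicas/rank come from torch.distributed (initialized by torchrun)
sampler = DistributedSampler(gos_dataset, num_replicas=1, rank=0, shuffle=True)
batch_sampler_train = BatchSampler(sampler, batch_size=4, drop_last=True)
dataloader = DataLoader(gos_dataset, batch_sampler=batch_sampler_train, num_workers=2)

for epoch in range(2):
    sampler.set_epoch(epoch)          # reshuffles consistently across processes each epoch
    for images, labels in dataloader:
        pass                          # images: (4, 3, 64, 64), labels: (4, 1, 64, 64)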

valid_dataloader,valid_datasets = create_dataloader()

net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])

optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)       # only the new HQ modules have requires_grad=True
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop_epoch)
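
Because the SAM decoder weights were frozen in MaskDecoderHQ.__init__, only the HQ modules actually get updated. A self-contained sketch of that pattern (hyperparameter values are illustrative, not the repo's exact settings):

import torch.nn as nn
import torch.optim as optim

# stand-in model: a frozen "pretrained" part plus a small trainable head
net = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 32))
for p in net[0].parameters():
    p.requires_grad = False            # mimic the frozen SAM mask-decoder weights

trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable, lr=1e-3)                              # lr is illustrative
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)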

train(args,net,optimizer,train_dataloader,valid_dataloader,lr_scheduler)->
- sam = sam_model_registry[args.model_type](checkpoint)
- sam = torch.nn.parallel.DistributedDataParallel(sam,...)
- for epoch in range(epoch_start,epoch_num):
    train_dataloader.batch_sampler.sampler.set_epoch(epoch)
    for data in metric_logger.log_every(train_dataloader,1000):
        inputs, labels = data['image'], data['label']
        imgs = inputs.permute(0, 2, 3, 1).cpu().numpy()    # per-image HWC arrays used when packing prompts below
        
        # build three candidate prompts from the GT masks; one is picked at random per image
        input_keys = ['box', 'point', 'noise_mask']
        labels_box = misc.masks_to_boxes(labels[:, 0, :, :])              # tight bounding box per GT mask
        labels_points = misc.masks_sample_points(labels[:, 0, :, :])      # random foreground points per GT mask
        labels_256 = F.interpolate(labels, (256, 256), mode='bilinear')   # low-res GT for the mask prompt
        labels_noisemask = misc.masks_noise(labels_256)                   # noised GT mask as a coarse mask prompt
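
For intuition, a box prompt can be read straight off a binary mask. A simplified stand-in for misc.masks_to_boxes (torchvision.ops.masks_to_boxes does the same job); it assumes every mask is non-empty:

import torch

def masks_to_boxes_sketch(masks: torch.Tensor) -> torch.Tensor:
    """masks: (B, H, W) binary masks -> boxes: (B, 4) as (x1, y1, x2, y2)."""
    boxes = torch.zeros(masks.shape[0], 4, device=masks.device)
    for i, m in enumerate(masks):
        ys, xs = torch.where(m > 0.5)                 # coordinates of foreground pixels
        boxes[i] = torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).float()
    return boxes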
        
        # pack one randomly chosen prompt per image in Sam.forward's batched-input format
        batched_input = []
        for b_i in range(len(imgs)):
            dict_input = dict()
            input_image = torch.as_tensor(imgs[b_i]).permute(2, 0, 1)     # HWC -> CHW
            dict_input['image'] = input_image
            dict_input['original_size'] = imgs[b_i].shape[:2]
            input_type = random.choice(input_keys)
            if input_type == 'box':
                dict_input['boxes'] = labels_box[b_i:b_i+1]
            elif input_type == 'point':
                dict_input['point_coords'] = labels_points[b_i:b_i+1]
                dict_input['point_labels'] = torch.ones(labels_points[b_i:b_i+1].shape[1])[None, :]   # all foreground
            elif input_type == 'noise_mask':
                dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1]
            batched_input.append(dict_input)
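
For reference, one illustrative batched_input entry with a box prompt, following segment_anything's documented batched-input interface (dummy data; H, W refer to the already-resized model input):

import torch

example_entry = {
    'image':         torch.zeros(3, 1024, 1024, dtype=torch.uint8),   # CHW, already resized to the model's input frame
    'original_size': (768, 1024),                                     # image size before resizing
    'boxes':         torch.tensor([[100., 150., 400., 500.]]),        # B x 4, XYXY in the resized frame
    # a point prompt would instead use 'point_coords' (B x N x 2) plus 'point_labels' (B x N, 1 = foreground),
    # and a mask prompt would use 'mask_inputs' (B x 1 x 256 x 256)
}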

        with torch.no_grad():
            batched_output, interm_embeddings = sam(batched_input)   # frozen SAM forward; interm_embeddings = early ViT-layer features for the HQ decoder

        # batched_output is a list of per-image dicts; gather what the HQ decoder needs
        encoder_embedding = torch.cat([out['encoder_embedding'] for out in batched_output], dim=0)
        image_pe = [out['image_pe'] for out in batched_output]
        sparse_embeddings = [out['sparse_embeddings'] for out in batched_output]
        dense_embeddings = [out['dense_embeddings'] for out in batched_output]

        masks_hq = net(encoder_embedding, image_pe, sparse_embeddings, dense_embeddings,
                       multimask_output=False, hq_token_only=True, interm_embeddings=interm_embeddings)
        
        loss_mask, loss_dice = loss_masks(masks_hq, labels / 255.0, ...)   # BCE + dice (GT masks are stored as 0/255, hence the /255)
        loss = loss_mask + loss_dice
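
A simplified version of the two loss terms (dense BCE + dice over the full mask; the repo evaluates them on sampled points, but the formulation is the same):

import torch
import torch.nn.functional as F

def dice_loss_sketch(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1.0) -> torch.Tensor:
    """logits, targets: (N, H, W), targets in {0, 1}."""
    probs = logits.sigmoid().flatten(1)
    targets = targets.flatten(1)
    inter = (probs * targets).sum(-1)
    union = probs.sum(-1) + targets.sum(-1)
    return (1 - (2 * inter + eps) / (union + eps)).mean()

def mask_losses_sketch(pred_logits: torch.Tensor, gt_masks: torch.Tensor):
    loss_mask = F.binary_cross_entropy_with_logits(pred_logits, gt_masks)   # per-pixel BCE
    loss_dice = dice_loss_sketch(pred_logits, gt_masks)                     # region-overlap term
    return loss_mask, loss_dice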
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    lr_scheduler.step()

    torch.save(net.module.state_dict(), ...)   # periodically checkpoint the HQ decoder (MaskDecoderHQ) weights

Launching training:

torchrun --nproc_per_node=<num_gpus> train/train.py

Note that training goes through the Sam class directly (sam(batched_input) as above), while inference uses the SamPredictor.predict interface.
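
For comparison, a typical inference call through the predictor API (standard segment_anything usage; the sam-hq fork's predict additionally accepts an hq_token_only flag, omitted here; the checkpoint path is illustrative):

import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_l"](checkpoint="sam_hq_vit_l.pth")    # illustrative checkpoint path
predictor = SamPredictor(sam)

image = np.zeros((768, 1024, 3), dtype=np.uint8)                    # HWC uint8 RGB image
predictor.set_image(image)                                          # runs the image encoder once per image
masks, scores, logits = predictor.predict(
    box=np.array([100, 150, 400, 500]),                             # XYXY box prompt in original image coordinates
    multimask_output=False,
)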
