sam-hq/train/train.py->
train_datasets =
valid_datasets =
net = MaskDecoderHQ()->
- self.load_state_dict(torch.load(checkpoint))
- for n, p in self.named_parameters(): p.requires_grad = False # freeze pretrained SAM weights (not trained)
- self.hf_token = nn.Embedding(1,256)
- self.hf_mlp = MLP(256,256,256//8,3)
- self.compress_vit_feat =
- self.embedding_encoder =
- self.embedding_maskfeature =
main(net,train_datasets,valid_datasets,args)->
train_im_gt_list = get_im_gt_name_dict(train_datasets)
train_dataloaders,train_datasets = create_dataloader(train_im_gt_list,transforms,..)
- gos_dataset = [OnlineDataset(),...]
- gos_dataset = Concat(gos_dataset)
- sampler = DistributedSampler(gos_dataset)
- batch_sampler_train = torch.utils.data.BatchSampler(sampler,bs,...)
- dataloader = DataLoader(gos_dataset,batch_sampler_train,...)
valid_dataloader,valid_datasets = create_dataloader()
net = torch.nn.parallel.DistributedDataParallel(net)
optimizer = optim.Adam()
lr_scheduler = torch.optim.lr_scheduler.StepLR()
train(args,net,optimizer,train_dataloader,valid_dataloader,lr_scheduler)->
- sam = sam_model_registry[args.model_type](checkpoint)
- sam = torch.nn.parallel.DistributedDataParallel(sam,...)
- for epoch in range(epoch_start,epoch_num):
train_dataloader.batch_sampler.sampler.set_epoch(epoch)
for data in metric_logger.log_every(train_dataloader,1000):
imgs,labels = data['image'],data['label']
# input prompt
input_keys = ['box','point','noise_mask']
labels_box = misc.masks_to_boxes(labels[:,0,:,:])
label_points = misc.masks_sample_points(labels[:,0,:,:])
labels_256 = F.interpolate(labels,(256,256),mode='bilinear')
labels_noisemask = misc.masks_noise(labels_256)
for b_i in range(len(imgs)):
input_image = torch.as_tensor(imgs[b_i]).permute(2,0,1)
dict_input['image'] = input_image
input_type = random.choice(input_keys)
if input_type == 'box':
dict_input['boxes'] = labels_box[b_i:b_i+1]
elif input_type == 'point':
dict_input['point_coords'] = label_points[b_i:b_i+1]
elif input_type == 'noise_mask':
dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1]
batched_input.append(dict_input)
with torch.no_grad():
batched_output,interm_embeddings = sam(batched_input)
encoder_embedding = [batched_output['encoder_embeddings']...]
image_pe = []
sparse_embeddings = []
dense_embeddings = []
mask_hq = net(encoder_embedding,image_pe,sparse_embeddings,dense_embeddings,hq_token_only,interm_embeddings)
loss_mask,loss_dice = loss_masks(mask_hq,labels/255,...)
loss = loss_mask + loss_dice
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
torch.save()
Training launch command:
torchrun --nproc_per_node=<num_gpus> tools/train_hq_sam.py
Note: training goes through the Sam class's forward pass, while inference uses the predictor's predict() path instead.