PyTorch: resolving the error "RuntimeError: Output 0 of SelectBackward is a view and is being modified inplace"

While working on a semi-supervised setup, I need to compute the loss for the labeled and the unlabeled data separately. In multi-GPU distributed (DDP) training, however, forwarding the same model twice is not allowed and raises an error (RuntimeError: Expected to mark a variable ready only once. This error is caused by one ...). The workaround is to merge the labeled and unlabeled data, run a single forward pass through the model, and then split the results apart again. The idea is as follows:

  • Merge the labeled and unlabeled data;
  • Run one forward pass through the model to get the predictions;
  • Split the predictions back into the labeled and unlabeled parts;
  • Compute the two losses separately (a minimal standalone sketch follows, then the project code).
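
The same idea as a minimal, self-contained sketch (toy model and random tensors; every name below is illustrative and not part of the project's API):

    import torch
    import torch.nn as nn

    # toy stand-ins for the real student model and data
    model = nn.Linear(16, 4)
    sup_x, sup_y = torch.randn(8, 16), torch.randint(0, 4, (8,))
    unsup_x = torch.randn(12, 16)

    # 1) merge labeled and unlabeled data into one batch
    combined = torch.cat([sup_x, unsup_x], dim=0)
    num_sup = sup_x.size(0)

    # 2) a single forward pass (in the real setup the model is DDP-wrapped)
    pred = model(combined)

    # 3) split the predictions back into labeled / unlabeled parts
    sup_pred, unsup_pred = pred[:num_sup], pred[num_sup:]

    # 4) compute the two losses separately
    sup_loss = nn.functional.cross_entropy(sup_pred, sup_y)
    # unsup_pred would feed the unsupervised / consistency loss

In the project, the forward function is implemented as follows:
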
    def foward_student_train(self, sup_data, unsup_data):
        '''Forward the student model once on the combined batch.
        '''
        # merge the labeled and unlabeled input data
        student_data, sup_data_length = \
            self.combine_student_data(sup_data, unsup_data)
        # single forward pass
        img_feats, _ = self.student.extract_feat(
            points=None,
            img=student_data['img_inputs'],
            img_metas=student_data['img_metas'])    # bev_feats, None
        student_info = self.student.pts_bbox_head(img_feats)

        # split the predictions back into labeled / unlabeled parts
        sup_info, unsup_info = self.split_student_data(
            student_info, sup_data_length)
        # compute the supervised loss
        loss_inputs = [sup_data['gt_bboxes_3d'],
                       sup_data['gt_labels_3d'],
                       sup_info]
        sup_loss = self.student.pts_bbox_head.loss(*loss_inputs)

        return sup_loss, unsup_info
    
    # merge the input data (uses deepcopy, i.e. `from copy import deepcopy`)
    def combine_student_data(self, sup_data, unsup_data):
        '''Combine sup and unsup data for the student model.
        '''
        assert isinstance(sup_data, dict) and \
            isinstance(unsup_data, dict)

        new_student_data = deepcopy(sup_data)
        keys = sup_data.keys()
        for key in keys:
            if key == 'img_inputs':
                # image tensors are concatenated along the batch dimension
                new_student_data[key] = self.combine_imgs(
                    sup_data[key], unsup_data[key])
            else:
                # list-like entries (e.g. img_metas) are concatenated as lists
                new_student_data[key] = sup_data[key] + \
                                        unsup_data[key]
        # also return the two batch sizes so the outputs can be split later
        return new_student_data, (len(sup_data['img_metas']),
                                  len(unsup_data['img_metas']))
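
combine_imgs is not shown in the post; conceptually it just concatenates the supervised and unsupervised image tensors along the batch dimension. A possible standalone sketch (hypothetical, not the original implementation; it assumes img_inputs is either a single batched tensor or a tuple/list of batched tensors, as in BEVDet-style pipelines):

    import torch

    # hypothetical sketch of combine_imgs
    def combine_imgs(sup_imgs, unsup_imgs):
        if isinstance(sup_imgs, (list, tuple)):
            # e.g. img_inputs = (imgs, rots, trans, ...): concatenate element-wise
            return [torch.cat([s, u], dim=0)
                    for s, u in zip(sup_imgs, unsup_imgs)]
        return torch.cat([sup_imgs, unsup_imgs], dim=0)

Back in the project, splitting the combined head outputs apart again looks like this: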

    # split the prediction results
    def split_student_data(self, student_data, student_data_length):
        '''Split the combined head outputs back into sup / unsup parts.
        '''
        sup_data, unsup_data = [], []
        for idx, data in enumerate(student_data):   # the head returns a 6-element list here
            sup_data.append([{}])
            unsup_data.append([{}])
            # each element is assumed to be a list of length 1
            for key in data[0].keys():
                # clone() so the split results are real copies rather than autograd
                # views of the head outputs; this is what avoids the
                # "... is a view and is being modified inplace" RuntimeError
                sup_data[idx][0][key] = \
                    data[0][key][:student_data_length[0], ...].clone()
                unsup_data[idx][0][key] = \
                    data[0][key][student_data_length[0]:, ...].clone()
        return tuple(sup_data), tuple(unsup_data)
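
The .clone() calls above are what avoid the error in the title. Slicing the combined head outputs produces autograd views, and PyTorch forbids in-place modification of certain views, for example those produced by ops that return multiple views or views created inside custom autograd Functions; the *Backward name in the message (SelectBackward here) reflects the op that created the view. Cloning turns the slices into independent tensors, so later in-place writes in the loss code can no longer hit this restriction. A minimal, self-contained reproduction of the same class of error, using torch.split (which returns views that may not be modified in place):

    import torch

    feats = torch.randn(6, 3, requires_grad=True) * 2   # non-leaf tensor in the graph
    sup, unsup = feats.split([4, 2], dim=0)              # split returns views of feats

    try:
        sup += 1    # in-place write on a multi-output view
    except RuntimeError as e:
        # e.g. "Output 0 of SplitWithSizesBackward0 is a view and is being
        # modified inplace. This view is the output of a function that returns
        # multiple views. ..."
        print(e)

    # cloning first makes the slice an independent tensor, so the write is fine
    sup = feats.split([4, 2], dim=0)[0].clone()
    sup += 1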
