Watch out for channel order when using PyTorch

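Images read with OpenCV (cv2.imread) are in BGR channel order with an HWC layout (height, width, channels). PyTorch expects CHW, so before using such an image in torch you need to convert both the channel order (BGR to RGB) and the dimension order (HWC to CHW).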
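A minimal NumPy sketch of both conversions follows, assuming a 3-channel image as returned by cv2.imread. The helper name `bgr_hwc_to_rgb_chw` is hypothetical. It reverses the channel axis with slicing (`[::-1]`) instead of the `(2, 1, 0)` tuple indexing used in this post. Slicing produces a view with a negative stride, which `torch.from_numpy()` refuses, so the result is copied into a contiguous array at the end.

```python
import numpy as np


def bgr_hwc_to_rgb_chw(img: np.ndarray) -> np.ndarray:
    """Convert an OpenCV-style BGR image (H, W, 3) to RGB in (3, H, W) order."""
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) array, got shape {img.shape}")
    rgb = img[:, :, ::-1]         # reverse the channel axis: BGR -> RGB (view, negative stride)
    chw = rgb.transpose(2, 0, 1)  # move channels first: HWC -> CHW (still a view)
    # torch.from_numpy() rejects negative strides, so return a contiguous copy
    return np.ascontiguousarray(chw)


# A 2x2 "image" whose B, G, R planes hold 10, 20, 30
bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0], bgr[..., 1], bgr[..., 2] = 10, 20, 30
chw = bgr_hwc_to_rgb_chw(bgr)
print(chw.shape)     # (3, 2, 2)
print(chw[:, 0, 0])  # [30 20 10]
```

Indexing with `(2, 1, 0)` is fancy indexing, which already returns a copy with positive strides. That is why the code in this post can hand its arrays to `torch.from_numpy()` without an explicit copy.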

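# Requires: import cv2, numpy as np, skimage.util, torch (method of a torch.utils.data.Dataset subclass)
    def __getitem__(self, idx):
        '''Load a clean image and a Gaussian-noised copy of it.

        Args:
          idx: (int) image index.

        Returns:
          img_dis: (tensor) noisy image, RGB, CHW, float32 in [0, 1].
          img_org: (tensor) clean image, RGB, CHW, float32 in [0, 1].
        '''
        # Load image: cv2 returns a uint8 array in BGR order with HWC layout
        fname_org = self.fnames[idx]
        img_org = cv2.imread(self.root_src + 'dn_dataset/' + fname_org)
        if img_org is None:  # cv2.imread returns None instead of raising on a bad path
            raise FileNotFoundError(self.root_src + 'dn_dataset/' + fname_org)
        # img_org = Image.open(self.root_src + 'reference_cutBlock' + fname_org)  # PIL alternative
        # img_org = np.asarray(img_org)

        # Noise std is coin on the 0-255 scale (coin is in 0..49). random_noise converts
        # the uint8 input to float64 in [0, 1] first, so img_dis is already normalized.
        coin = np.random.randint(0, 50)
        img_dis = skimage.util.random_noise(img_org, mode='gaussian',
                                            var=(coin / 255.0) ** 2)  # add gaussian noise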

        # img_dis = img_dis[:, :, (2, 1, 0)]  # bgr012 to rgb210
        img_dis = img_dis.transpose([2, 0, 1])  # hwc to chw
        img_dis = img_dis[(2, 1, 0), :, :]  # bgr012 to rgb210

        img_org = img_org[:, :, (2, 1, 0)]/255.0  # bgr012 to rgb210
        img_org = img_org.transpose([2, 0, 1])  # hwc to chw

        img_dis = torch.from_numpy(img_dis).float()
        img_org = torch.from_numpy(img_org).float()
        # fname_org_dis = self.fnames_dis[idx]
        # img_dis = Image.open(self.root_src  +  'distorted_train_block/' + fname_org_dis)

        # if img_org.mode != 'RGB':
        #     img_org = img_org.convert('RGB')
        #
        # if img_dis.mode != 'RGB':
        #     img_dis = img_dis.convert('RGB')
        # img_org = self.transform(img_org)
        # img_dis = self.transform(img_dis)

        return img_dis, img_org
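To show why `img_dis` is already in [0, 1] and what `var=(coin / 255.0) ** 2` means, here is a rough NumPy approximation of the `random_noise(..., mode='gaussian')` call for a uint8 input. The function `add_gaussian_noise` is a hypothetical name, and this is a sketch of the behavior, not skimage's actual implementation. The image is scaled to [0, 1], noise with standard deviation `sigma_255 / 255` is added, and the result is clipped back to [0, 1].

```python
import numpy as np


def add_gaussian_noise(img_uint8: np.ndarray, sigma_255: float, rng=None) -> np.ndarray:
    """Approximate random_noise(img, mode='gaussian', var=(sigma_255 / 255) ** 2)."""
    rng = np.random.default_rng() if rng is None else rng
    img = img_uint8.astype(np.float64) / 255.0               # uint8 [0, 255] -> float64 [0, 1]
    noise = rng.normal(0.0, sigma_255 / 255.0, img.shape)    # std = sqrt(var) = sigma_255 / 255
    return np.clip(img + noise, 0.0, 1.0)                    # clip back into [0, 1]


rng = np.random.default_rng(0)
img = np.full((4, 4, 3), 128, dtype=np.uint8)
noisy = add_gaussian_noise(img, 25, rng)  # sigma of 25 on the 0-255 scale
print(noisy.dtype, noisy.shape)
```

With `coin` drawn from `np.random.randint(0, 50)`, each sample gets a noise level between 0 and 49 on the familiar 0-255 scale. That is a common way to train a denoiser on a range of noise strengths.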

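transforms.ToTensor() does two things. It converts the data to a Tensor, moving HWC to CHW along the way, and it normalizes it, scaling uint8 values from [0, 255] to [0, 1]. This code does not use it. Instead it does the work in two explicit steps (transpose, then torch.from_numpy), because img_dis has already been normalized by random_noise. ToTensor also never swaps BGR to RGB, so the channel reordering would be needed in any case.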
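The following NumPy sketch mirrors what torchvision's `ToTensor` does to an ndarray, as a way to check the reasoning above. The name `to_tensor_like` is hypothetical, and the function returns a NumPy array rather than a tensor. A 2-D array gets a channel axis, the array is transposed HWC to CHW, and only uint8 input is divided by 255.

```python
import numpy as np


def to_tensor_like(arr: np.ndarray) -> np.ndarray:
    """NumPy sketch of torchvision ToTensor applied to an (H, W[, C]) ndarray."""
    if arr.ndim == 2:               # grayscale (H, W) -> (H, W, 1)
        arr = arr[:, :, None]
    chw = arr.transpose(2, 0, 1)    # HWC -> CHW; channel order (BGR or RGB) is untouched
    if arr.dtype == np.uint8:       # only uint8 input is scaled to [0, 1]
        return chw.astype(np.float32) / 255.0
    return chw                      # other dtypes keep their values and dtype


u8 = np.full((2, 3, 3), 255, dtype=np.uint8)  # like img_org straight from cv2.imread
f64 = np.full((2, 3, 3), 0.5)                 # like img_dis from random_noise
print(to_tensor_like(u8).dtype, to_tensor_like(f64).dtype)
```

For a float64 array such as `img_dis`, ToTensor would only transpose it and leave the values as they are. For the uint8 `img_org` it would also rescale. In both cases the channels would remain in BGR order.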
