pytorch读取tiny-imagenet-200的验证集(val)

 pytorch读取tiny-imagenet-200的验证集(val)_第1张图片

  ori_train = torchvision.datasets.ImageFolder(root= args.datadir + '/tiny-imagenet-200/train/', transform=transform)
  #可以获取class_idx的映射
  class_idx = ori_train.class_to_idx

val_annotations.txt中存储着每个图片对应的类别

获取验证集的标签

            test_target = []
            #读取val_annotations.txt
            test_data_dir = "./data/tiny-imagenet-200/val"
            with open(test_data_dir + "/val_annotations.txt", 'r') as file:
                # 读取每一行并存储在数组中
                lines = file.readlines()
            # 输出每一行的数据
            for line in lines:
                content = line.strip().split("\t")
                target = class_idx[content[1]]
                test_target.append(target)

 读取图片信息

ori_test_o = torchvision.datasets.ImageFolder(root= args.datadir + '/tiny-imagenet-200/val/', transform=transform)

自定义Dataset

ori_test = Imagenet_dataset(ori_test_o,test_target)


class Imagenet_dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, targets):
        self.dataset = dataset
        self.targets = targets

    def __getitem__(self, idx):
        img = self.dataset[idx][0]
        label = self.targets[idx]
        return (img, label)

    def __len__(self):
        return len(self.dataset)

你可能感兴趣的:(python,pytorch,人工智能,python)